|
|
|
""" |
|
Named Entity Recognition (NER) using Transformers |
|
Extracts named entities such as people (PER), locations (LOC) and organizations (ORG) from text
|
""" |
|
|
|
import argparse
import json
import logging
import os
import re
from typing import Any, Dict, List, Union

from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
|
|
|
|
|
# Root-logger configuration, applied as soon as this module is imported.
# logging.basicConfig() does nothing once the root logger has a handler, so a
# later basicConfig() call without force=True will not override these settings.
_LOG_FORMAT = '%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s'
_LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
class TransformerNER:
    """Named entity recognizer backed by a Hugging Face token-classification pipeline.

    Besides running the model, it supports switching model / aggregation
    strategy at runtime and applies a few rule-based corrections to the raw
    model output (see ``_post_process_entities``).
    """

    # Shortcut name -> Hugging Face Hub model id. Every target is loaded with
    # AutoModelForTokenClassification, so it must be an NER (token
    # classification) checkpoint.
    MODELS = {
        "dslim-bert": "dslim/bert-base-NER",
        "dbmdz-bert": "dbmdz/bert-large-cased-finetuned-conll03-english",
        "xlm-roberta": "xlm-roberta-large-finetuned-conll03-english",
        # distilbert-base-cased-distilled-squad (the old target) is a
        # question-answering checkpoint with no NER head.
        "distilbert": "elastic/distilbert-base-cased-finetuned-conll03-english",
    }

    # Vocabularies for the post-processing heuristics.
    _COUNTRIES = frozenset({"India", "China", "USA", "UK", "Germany", "France", "Japan"})
    _SPLIT_VERBS = frozenset({"launches", "announces", "says", "opens", "creates", "launch"})
    _TECH_TERMS = frozenset({"electric", "suv", "car", "vehicle", "app", "software", "ai", "robot", "global"})
    # Endings that, together with Title case, make a one-word ORG look like a given name.
    _NAME_ENDINGS = ("i", "a", "o")
    # Lower-case alphanumeric runs; lets rules compare whole words, not substrings.
    _WORD_RE = re.compile(r"[a-z0-9]+")

    def __init__(self, model_name: str = "dslim/bert-base-NER", aggregation_strategy: str = "simple"):
        """Load ``model_name`` and build the NER pipeline.

        Args:
            model_name: A key of ``MODELS`` or any Hugging Face model id.
                Default: dslim/bert-base-NER (a BERT-base model fine-tuned for NER).
            aggregation_strategy: How sub-word predictions are merged into
                entities; passed through to ``transformers.pipeline``.

        Raises:
            Whatever the transformers loaders raise if the model cannot be
            downloaded or loaded.
        """
        self.logger = logging.getLogger(__name__)
        # Downloaded weights are cached next to this source file.
        self.cache_dir = os.path.join(os.path.dirname(__file__), "model_cache")
        os.makedirs(self.cache_dir, exist_ok=True)
        self._load_model(model_name, aggregation_strategy)

    def _load_model(self, model_name: str, aggregation_strategy: str = "simple"):
        """Load tokenizer, model and pipeline for ``model_name``.

        Instance state (``tokenizer``, ``model``, ``nlp``, ``current_model_name``,
        ``aggregation_strategy``) is only updated after every step succeeded,
        so a failed load leaves the previous configuration fully usable.
        """
        resolved_name = self.MODELS.get(model_name, model_name)

        self.logger.info("Loading model: %s", resolved_name)
        self.logger.info("Cache directory: %s", self.cache_dir)
        self.logger.info("Aggregation strategy: %s", aggregation_strategy)

        tokenizer = AutoTokenizer.from_pretrained(resolved_name, cache_dir=self.cache_dir)
        model = AutoModelForTokenClassification.from_pretrained(resolved_name, cache_dir=self.cache_dir)
        nlp = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy=aggregation_strategy)

        self.tokenizer = tokenizer
        self.model = model
        self.nlp = nlp
        self.current_model_name = model_name
        self.aggregation_strategy = aggregation_strategy
        self.logger.info("Model loaded successfully!")

    def switch_model(self, model_name: str, aggregation_strategy: str = None):
        """Switch to another model, keeping the current aggregation unless one is given.

        Returns:
            True on success; False (after logging the error) on failure, in
            which case the previously loaded model stays active.
        """
        if aggregation_strategy is None:
            aggregation_strategy = self.aggregation_strategy
        try:
            self._load_model(model_name, aggregation_strategy)
        except Exception as e:
            self.logger.error("Failed to load model '%s': %s", model_name, e)
            return False
        return True

    def change_aggregation(self, aggregation_strategy: str):
        """Rebuild the pipeline of the current model with a new aggregation strategy.

        The already-loaded model and tokenizer are reused instead of being
        reloaded from disk.

        Returns:
            True on success; False (after logging the error) on failure, in
            which case the previous pipeline and strategy stay active.
        """
        try:
            self.logger.info("Aggregation strategy: %s", aggregation_strategy)
            nlp = pipeline("ner", model=self.model, tokenizer=self.tokenizer,
                           aggregation_strategy=aggregation_strategy)
        except Exception as e:
            self.logger.error("Failed to change aggregation to '%s': %s", aggregation_strategy, e)
            return False
        self.nlp = nlp
        self.aggregation_strategy = aggregation_strategy
        return True

    @classmethod
    def _mentions_tech_term(cls, text: str) -> bool:
        """Return True if a whole word of ``text`` (or its -s singular) is a tech term."""
        for word in cls._WORD_RE.findall(text.lower()):
            if word in cls._TECH_TERMS or (word.endswith("s") and word[:-1] in cls._TECH_TERMS):
                return True
        return False

    def _post_process_entities(self, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Apply rule-based fixes for common boundary and classification errors.

        Rules run in order, each seeing the label produced by the previous ones:

        1. Known country names are relabelled LOC.
        2. A single Title-case ORG word ending in i/a/o (e.g. "Modi") is
           relabelled PER.
        3. A multi-word ORG shaped "<Titleword> <verb> ..." (e.g. "Tesla
           launches ...") is truncated to its first word and relabelled PER.
        4. An entity whose text contains a tech term as a whole word (plural
           -s allowed) is relabelled MISC.

        Entities whose text is empty after stripping are dropped; otherwise
        input order is kept and each result is a copy of its input entity.

        Args:
            entities: Dicts with at least "entity", "text" and "start" keys,
                as produced by ``extract_entities``.

        Returns:
            The corrected entity dicts.
        """
        corrected = []
        for entity in entities:
            text = entity["text"].strip()
            if not text:
                continue

            fixed = entity.copy()
            label = entity["entity"]
            words = text.split()

            if text in self._COUNTRIES and label != "LOC":
                self.logger.debug("Fixed: '%s' %s -> LOC", text, label)
                label = "LOC"

            # Title case alone matches almost every one-word organisation
            # ("Google", "Microsoft"), so both conditions are required.
            if (label == "ORG" and len(words) == 1 and text.istitle()
                    and text.lower().endswith(self._NAME_ENDINGS)):
                self.logger.debug("Fixed: '%s' ORG -> PER", text)
                label = "PER"

            if (label == "ORG" and len(words) >= 2 and words[0].istitle()
                    and words[1].lower() in self._SPLIT_VERBS):
                self.logger.debug("Split entity: '%s' -> PER: '%s'", text, words[0])
                text = words[0]
                fixed["text"] = text
                fixed["end"] = fixed["start"] + len(text)
                label = "PER"

            # Checked against the (possibly truncated) text so a split-off name
            # is not relabelled because of words that were just removed.
            if label != "MISC" and self._mentions_tech_term(text):
                self.logger.debug("Fixed: '%s' %s -> MISC", text, label)
                label = "MISC"

            fixed["entity"] = label
            corrected.append(fixed)

        return corrected

    def extract_entities(
        self, text: str, return_both: bool = False
    ) -> Union[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
        """Extract named entities from ``text``.

        Each entity is a dict with "entity" (label), "text", "score" (a plain
        float rounded to 4 places), "start" and "end".

        Args:
            text: The text to analyse.
            return_both: If True, return ``{"cleaned": raw, "corrected": fixed}``
                with the model output before and after post-processing.
                If False (default, backward compatible), return only the
                corrected list.
        """
        raw_entities = self.nlp(text)

        cleaned_entities = [
            {
                # "entity_group" is only present when an aggregation strategy
                # groups tokens; with aggregation "none" the label is "entity".
                "entity": item.get("entity_group", item.get("entity")),
                "text": item["word"],
                # float() turns NumPy scalars from the pipeline into plain
                # floats so results can be passed to json.dumps.
                "score": round(float(item["score"]), 4),
                "start": item["start"],
                "end": item["end"],
            }
            for item in raw_entities
        ]

        corrected_entities = self._post_process_entities(cleaned_entities)

        if return_both:
            return {"cleaned": cleaned_entities, "corrected": corrected_entities}
        return corrected_entities

    def extract_entities_debug(self, text: str) -> Dict[str, List[Dict[str, Any]]]:
        """Return ``{"cleaned": ..., "corrected": ...}`` entity lists for debugging."""
        return self.extract_entities(text, return_both=True)

    def extract_entities_by_type(self, text: str) -> Dict[str, List[str]]:
        """Return the distinct corrected entity texts grouped by label.

        Keys follow first-appearance order of the labels; within a label,
        texts are de-duplicated and keep first-appearance order.
        """
        grouped: Dict[str, List[str]] = {}
        for entity in self.extract_entities(text):
            bucket = grouped.setdefault(entity["entity"], [])
            if entity["text"] not in bucket:
                bucket.append(entity["text"])
        return grouped

    @staticmethod
    def _format_entity_line(entity: Dict[str, Any]) -> str:
        """Render one entity as a bullet line for the text reports."""
        return f"• [{entity['entity']}] '{entity['text']}' (confidence: {entity['score']})"

    def format_output(self, entities: List[Dict[str, Any]], text: str) -> str:
        """Format ``entities`` as a human-readable report that includes ``text``."""
        output = [
            "=" * 60,
            "NAMED ENTITY RECOGNITION RESULTS",
            "=" * 60,
            f"\nOriginal Text:\n{text}\n",
            "-" * 40,
            "Entities Found:",
            "-" * 40,
        ]
        if entities:
            output.extend(self._format_entity_line(entity) for entity in entities)
        else:
            output.append("No entities found.")
        return "\n".join(output)

    def format_debug_output(self, debug_results: Dict[str, List[Dict[str, Any]]], text: str) -> str:
        """Format a before/after report of post-processing.

        Args:
            debug_results: ``{"cleaned": [...], "corrected": [...]}`` as returned
                by ``extract_entities_debug``.
            text: The analysed text, echoed in the report.
        """
        cleaned = debug_results["cleaned"]
        corrected = debug_results["corrected"]

        output = [
            "=" * 70,
            "NER DEBUG: BEFORE & AFTER POST-PROCESSING",
            "=" * 70,
            f"\nOriginal Text:\n{text}\n",
            "🔍 BEFORE Post-Processing (Raw Model Output):",
            "-" * 50,
        ]
        if cleaned:
            output.extend(self._format_entity_line(entity) for entity in cleaned)
        else:
            output.append("No entities found by model.")
        output.append("")

        output.append("✨ AFTER Post-Processing (Corrected):")
        output.append("-" * 50)
        if corrected:
            output.extend(self._format_entity_line(entity) for entity in corrected)
        else:
            output.append("No entities after correction.")

        output.append("")
        output.append("📝 Changes Made:")
        output.append("-" * 25)

        # _post_process_entities keeps order and only drops whitespace-only
        # entities, so the surviving cleaned entities pair 1:1 with corrected
        # ones. Pairing by position avoids matching unrelated entities that
        # merely share a substring.
        survivors = [entity for entity in cleaned if entity["text"].strip()]
        fixed_lines = []
        split_lines = []
        for before, after in zip(survivors, corrected):
            if before["text"] != after["text"]:
                split_lines.append(f" Split: '{before['text']}' → '{after['text']}'")
            elif before["entity"] != after["entity"]:
                fixed_lines.append(f" Fixed: '{before['text']}' {before['entity']} → {after['entity']}")

        changes = fixed_lines + split_lines
        output.extend(changes or [" No changes made by post-processing."])
        return "\n".join(output)
|
|
|
|
|
def interactive_mode(ner: TransformerNER):
    """Run a read-eval-print loop that analyses text with an already-loaded model.

    Each input line is either a command (see ``help``) or text to analyse.
    Commands that take no argument (quit/exit/q, grouped, json, debug,
    models/list-models, info, help) are recognised only when typed on their
    own. A sentence that merely starts with one of those words (e.g. "Exit
    polls show ...") is therefore analysed instead of being swallowed.
    ``model``, ``agg``/``aggregation`` and ``file`` take the rest of the line
    as their argument.

    Returns when the user quits, presses Ctrl-C, or input reaches EOF.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("INTERACTIVE NER MODE")
    print(banner)
    print("Enter text to analyze (or 'quit' to exit)")
    print("Commands: 'help' for full list, 'model <name>' to switch models")
    print(banner)

    grouped_mode = False
    json_mode = False
    debug_mode = False
    valid_strategies = ["simple", "first", "average", "max"]

    def on_off(flag):
        return 'ON' if flag else 'OFF'

    def show_help():
        print("\n" + "=" * 50)
        print("INTERACTIVE COMMANDS")
        print("=" * 50)
        print("Output Modes:")
        print(f" grouped - Toggle grouped output (currently: {on_off(grouped_mode)})")
        print(f" json - Toggle JSON output (currently: {on_off(json_mode)})")
        print(f" debug - Toggle debug mode - show before/after post-processing (currently: {on_off(debug_mode)})")
        print("\nModel Management:")
        print(" model <name> - Switch to model (e.g., 'model dbmdz-bert')")
        print(" models - List available model shortcuts")
        print(" agg <strat> - Change aggregation (simple/first/average/max)")
        print("\nFile Operations:")
        print(" file <path> - Analyze text from file")
        print("\nInformation:")
        print(" info - Show current configuration")
        print(" help - Show this help")
        print(" quit - Exit interactive mode")
        print("=" * 50)

    def show_models():
        print("\nAvailable model shortcuts:")
        print("-" * 50)
        for shortcut, full_name in TransformerNER.MODELS.items():
            current = " (current)" if ner.current_model_name in (shortcut, full_name) else ""
            print(f" {shortcut:<15} -> {full_name}{current}")
        print("\nUsage: 'model <shortcut>' (e.g., 'model dbmdz-bert')")
        print(f"Aggregation strategies: {valid_strategies}")
        print("Usage: 'agg <strategy>' (e.g., 'agg first')")

    def show_info():
        resolved_name = ner.MODELS.get(ner.current_model_name, ner.current_model_name)
        print("\nCurrent Configuration:")
        print(f" Model: {ner.current_model_name}")
        print(f" Full name: {resolved_name}")
        print(f" Aggregation: {ner.aggregation_strategy}")
        print(f" Grouped mode: {on_off(grouped_mode)}")
        print(f" JSON mode: {on_off(json_mode)}")
        print(f" Debug mode: {on_off(debug_mode)}")
        print(f" Cache dir: {ner.cache_dir}")

    def switch_model(model_name: str):
        print(f"Switching to model: {model_name}")
        if ner.switch_model(model_name):
            print(f"✅ Successfully switched to {model_name}")
            return True
        print(f"❌ Failed to switch to {model_name}")
        return False

    def change_aggregation(strategy: str):
        if strategy not in valid_strategies:
            print(f"❌ Invalid aggregation strategy. Valid options: {valid_strategies}")
            return False
        print(f"Changing aggregation to: {strategy}")
        if ner.change_aggregation(strategy):
            print(f"✅ Successfully changed aggregation to {strategy}")
            return True
        print(f"❌ Failed to change aggregation to {strategy}")
        return False

    def process_file(file_path: str):
        """Return the stripped file contents, or None (after printing why) on error."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                file_text = f.read()
            print(f"📁 Processing file: {file_path}")
            return file_text.strip()
        except Exception as e:
            print(f"❌ Error reading file '{file_path}': {e}")
            return None

    def print_entities(entities):
        """Print extraction results according to the current output mode."""
        if json_mode:
            # Pipeline scores may be NumPy scalars, which json cannot encode natively.
            print(json.dumps(entities, indent=2, default=float))
        elif grouped_mode:
            print("\nEntities by type:")
            print("-" * 30)
            if not entities:
                print("No entities found.")
            else:
                for entity_type, entity_list in entities.items():
                    print(f"{entity_type}: {', '.join(entity_list)}")
        elif not entities:
            print("No entities found.")
        else:
            print("\nEntities found:")
            print("-" * 20)
            for entity in entities:
                print(f"• [{entity['entity']}] '{entity['text']}' (confidence: {entity['score']})")

    while True:
        try:
            print("\n> ", end="", flush=True)
            user_input = input().strip()
            if not user_input:
                continue

            parts = user_input.split(None, 1)
            command = parts[0].lower()
            args = parts[1] if len(parts) > 1 else ""

            # Argument-less commands only count when typed on their own.
            if not args:
                if command in ('quit', 'exit', 'q'):
                    print("Goodbye!")
                    break
                if command == 'grouped':
                    grouped_mode = not grouped_mode
                    print(f"Grouped mode: {on_off(grouped_mode)}")
                    continue
                if command == 'json':
                    json_mode = not json_mode
                    print(f"JSON mode: {on_off(json_mode)}")
                    continue
                if command == 'debug':
                    debug_mode = not debug_mode
                    print(f"Debug mode: {on_off(debug_mode)}")
                    continue
                if command in ('models', 'list-models'):
                    show_models()
                    continue
                if command == 'info':
                    show_info()
                    continue
                if command == 'help':
                    show_help()
                    continue

            if command == 'model':
                if not args:
                    print("❌ Please specify a model name. Use 'models' to see available options.")
                else:
                    switch_model(args)
                continue

            if command in ('agg', 'aggregation'):
                if not args:
                    print("❌ Please specify an aggregation strategy: simple, first, average, max")
                else:
                    change_aggregation(args)
                continue

            text = user_input
            if command == 'file':
                if not args:
                    print("❌ Please specify a file path.")
                    continue
                file_content = process_file(args)
                if file_content is None:
                    continue  # process_file already reported the error
                if not file_content:
                    print(f"❌ File '{args}' contains no text.")
                    continue
                text = file_content

            if debug_mode:
                debug_results = ner.extract_entities_debug(text)
                print(ner.format_debug_output(debug_results, text))
                continue

            if grouped_mode:
                entities = ner.extract_entities_by_type(text)
            else:
                entities = ner.extract_entities(text)
            print_entities(entities)

        except KeyboardInterrupt:
            print("\n\nGoodbye!")
            break
        except EOFError:
            print("\nGoodbye!")
            break
        except Exception as e:
            # Keep the session alive on per-input failures; just report them.
            logger.error(f"Error processing text: {e}")
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments, run NER, and print the result.

    Text comes from ``--file`` or ``--text``. With ``--interactive``, or when
    neither source is given, the interactive loop is started instead. Input
    is read and validated before the model is loaded, so bad input fails
    fast. Output is a formatted report, a per-type listing (``--grouped``)
    or JSON (``--json``).
    """
    parser = argparse.ArgumentParser(description="Extract named entities from text using Transformers")
    parser.add_argument("--text", type=str, help="Text to analyze")
    parser.add_argument("--file", type=str, help="File containing text to analyze")
    parser.add_argument("--model", type=str, default="dslim/bert-base-NER",
                        help="HuggingFace model to use. Shortcuts: " + ", ".join(TransformerNER.MODELS))
    parser.add_argument("--aggregation", type=str, default="simple",
                        choices=["simple", "first", "average", "max"],
                        help="Aggregation strategy for subword tokens (default: simple)")
    parser.add_argument("--json", action="store_true", help="Output as JSON")
    parser.add_argument("--grouped", action="store_true", help="Group entities by type")
    parser.add_argument("--interactive", "-i", action="store_true", help="Start interactive mode")
    parser.add_argument("--list-models", action="store_true", help="List available model shortcuts")

    args = parser.parse_args()

    if args.list_models:
        print("\nAvailable model shortcuts:")
        print("-" * 40)
        for shortcut, full_name in TransformerNER.MODELS.items():
            print(f" {shortcut:<15} -> {full_name}")
        print(f"\nDefault aggregation strategies: {['simple', 'first', 'average', 'max']}")
        return

    # None means "no text source": fall back to interactive mode.
    text = None
    if not args.interactive:
        if args.file:
            try:
                with open(args.file, 'r', encoding='utf-8') as f:
                    text = f.read()
            except (OSError, UnicodeDecodeError) as e:
                parser.error(f"cannot read file '{args.file}': {e}")
        elif args.text:
            text = args.text

        if text is not None and not text.strip():
            logger.error("No text provided")
            return

    ner = TransformerNER(model_name=args.model, aggregation_strategy=args.aggregation)

    if text is None:
        interactive_mode(ner)
        return

    if args.grouped:
        entities = ner.extract_entities_by_type(text)
    else:
        entities = ner.extract_entities(text)

    if args.json:
        # Pipeline scores may be NumPy scalars, which json cannot encode natively.
        print(json.dumps(entities, indent=2, default=float))
    elif args.grouped:
        print("\n" + "=" * 60)
        print("ENTITIES GROUPED BY TYPE")
        print("=" * 60)
        if not entities:
            print("No entities found.")
        for entity_type, entity_list in entities.items():
            print(f"\n{entity_type}:")
            for item in entity_list:
                print(f" • {item}")
    else:
        print(ner.format_output(entities, text))
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # This module configures the root logger at import time, and
    # logging.basicConfig() is a no-op once the root logger has handlers.
    # force=True (Python 3.8+) replaces that configuration with the quieter
    # INFO-level format intended for script use.
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s', force=True)

    if len(sys.argv) == 1:
        # No CLI arguments: run a short demo over built-in sentences.
        example_sentences = [
            "Apple Inc. was founded by Steve Jobs in Cupertino, California.",
            "Barack Obama was the 44th President of the United States.",
            "The Eiffel Tower in Paris attracts millions of tourists each year.",
            "Google's CEO Sundar Pichai announced new AI features at the conference in San Francisco.",
            "Microsoft and OpenAI partnered to develop ChatGPT in Seattle."
        ]

        logging.info("Running demo with example sentences...\n")
        ner = TransformerNER()

        for sentence in example_sentences:
            print("\n" + "="*60)
            print(f"Input: {sentence}")
            print("-"*40)
            entities = ner.extract_entities_by_type(sentence)
            for entity_type, items in entities.items():
                print(f"{entity_type}: {', '.join(items)}")

        print("\n" + "="*60)
        print("\nTo analyze your own text, use:")
        print(" python ner_transformer.py --text 'Your text here'")
        print(" python ner_transformer.py --file input.txt")
        print(" python ner_transformer.py --json --grouped")
    else:
        main()
|
|