thlinhares committed on
Commit
8166d49
·
verified ·
1 Parent(s): ebd1fea

Create ner_analyzer.py

Browse files
Files changed (1) hide show
  1. ner_analyzer.py +70 -0
ner_analyzer.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
2
+ import torch
3
+ from typing import List, Tuple
4
+ import logging
5
+ from .base_analyzer import BaseAnalyzer
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
class NERAnalyzer(BaseAnalyzer):
    """Find people and organizations in a text with a token-classification model.

    The model's per-token labels are read as IOB tags (``B-<KIND>`` opens an
    entity, ``I-<KIND>`` continues it, ``O`` is "no entity"). Tokens tagged
    ``PESSOA`` or ``ORGANIZACAO`` are stitched back into names and reported as
    the representatives mentioned in the text.
    """

    # Label suffixes (the part after "B-"/"I-") that count as representatives.
    _REPRESENTATIVE_KINDS = ("PESSOA", "ORGANIZACAO")
    # Punctuation tokens glued to the preceding word instead of space-separated.
    _TRAILING_PUNCTUATION = frozenset(".,;:")

    def __init__(self):
        """Load the pretrained NER model and its tokenizer."""
        # NOTE(review): BaseAnalyzer.__init__ is not called here (unchanged
        # from the original); confirm the base class needs no initialization.
        self.model_name = "dominguesm/ner-legal-bert-base-cased-ptbr"
        logger.info(f"Carregando o modelo NER: {self.model_name}")
        self.model = AutoModelForTokenClassification.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        logger.info("Modelo NER e tokenizador carregados com sucesso")

    def extract_entities(self, text: str) -> List[Tuple[str, str]]:
        """Tag ``text`` and return the tokens that belong to an entity.

        Args:
            text: Raw text to analyze.

        Returns:
            ``(token, label)`` pairs, in text order, for every token whose
            predicted label is not ``"O"``. Tokens are tokenizer pieces and
            may carry the ``##`` sub-word prefix.
        """
        logger.debug("Iniciando extração de entidades com NER")
        # Anything past 512 tokens is truncated and therefore never tagged.
        inputs = self.tokenizer(text, max_length=512, truncation=True, return_tensors="pt")
        tokens = inputs.tokens()

        with torch.no_grad():
            logits = self.model(**inputs).logits
        # tolist() yields plain ints (matching id2label's keys) and, unlike
        # numpy(), also works when the tensor does not live on the CPU.
        label_ids = torch.argmax(logits, dim=2)[0].tolist()

        # Special tokens ([CLS], [SEP], ...) are not part of the text; whatever
        # label the model assigns to them must not be reported as an entity.
        special_tokens = set(self.tokenizer.all_special_tokens)
        id2label = self.model.config.id2label

        entities = []
        for token, label_id in zip(tokens, label_ids):
            if token in special_tokens:
                continue
            entity_label = id2label[label_id]
            if entity_label != "O":
                entities.append((token, entity_label))

        return entities

    def extract_representatives(self, entities: List[Tuple[str, str]]) -> List[str]:
        """Rebuild person and organization names from tagged tokens.

        Args:
            entities: ``(token, label)`` pairs as produced by
                ``extract_entities``.

        Returns:
            One string per person/organization, in order of appearance.
            Sub-word pieces (``##...``) are joined to the previous piece,
            separate words are joined with a single space, and a ``B-`` label
            always opens a new name so consecutive entities are not merged.
        """
        representatives = []
        current_kind = None  # "PESSOA"/"ORGANIZACAO" while a name is being built
        current_text = ""

        for token, label in entities:
            prefix, _, kind = label.partition("-")
            if prefix not in ("B", "I") or kind not in self._REPRESENTATIVE_KINDS:
                # Any other label ends the name currently being built.
                if current_text:
                    representatives.append(current_text)
                current_kind, current_text = None, ""
                continue

            is_subword = token.startswith("##")
            piece = token[2:] if is_subword else token

            # A sub-word piece continues its word even when tagged "B-";
            # otherwise "B-" or a change of kind marks the start of a new name.
            starts_new = kind != current_kind or (prefix == "B" and not is_subword)
            if starts_new:
                if current_text:
                    representatives.append(current_text)
                current_kind, current_text = kind, piece
            elif is_subword or not current_text or piece in self._TRAILING_PUNCTUATION:
                current_text += piece
            else:
                current_text += " " + piece

        if current_text:
            representatives.append(current_text)

        return representatives

    def analyze(self, text: str) -> List[str]:
        """Run entity extraction on ``text`` and return the representatives found."""
        entities = self.extract_entities(text)
        return self.extract_representatives(entities)

    def format_output(self, representatives: List[str]) -> str:
        """Render the representatives as a titled, one-per-line bullet list."""
        header = (
            "ANÁLISE DO CONTRATO SOCIAL (NER)\n\n"
            "REPRESENTANTES IDENTIFICADOS:\n"
        )
        return header + "".join(f"- {rep}\n" for rep in representatives)