thlinhares committed · Commit 8864085 · verified · 1 Parent(s): 4d79540

Update analyzers/ner_analyzer.py

Files changed (1)
  1. analyzers/ner_analyzer.py +65 -24
analyzers/ner_analyzer.py CHANGED
@@ -4,65 +4,106 @@ import torch
 from typing import List, Tuple
 import logging
 from .base_analyzer import BaseAnalyzer
-from huggingface_hub import hf_api
 
 logger = logging.getLogger(__name__)
 
 class NERAnalyzer(BaseAnalyzer):
     def __init__(self):
-        self.model_name = "pierreguillou/ner-bert-base-pt" # Modelo NER para português
+        self.model_name = "jpbahiaz/bert-base-portuguese-ner" # Modelo NER mais leve para português
         logger.info(f"Carregando o modelo NER: {self.model_name}")
 
-        # Passando o token de autenticação ao carregar o modelo
-        self.token = os.getenv("token_huggingface")
-        self.model = AutoModelForTokenClassification.from_pretrained(self.model_name, use_auth_token=self.token)
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_auth_token=self.token)
+        # Carregando o modelo e tokenizer sem necessidade de token de autenticação
+        self.model = AutoModelForTokenClassification.from_pretrained(self.model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        # Definindo as labels que queremos extrair (pessoas e organizações)
+        self.target_labels = ['B-PESSOA', 'I-PESSOA', 'B-ORGANIZACAO', 'I-ORGANIZACAO']
+
         logger.info("Modelo NER e tokenizador carregados com sucesso")
-
+
     def extract_entities(self, text: str) -> List[Tuple[str, str]]:
         logger.debug("Iniciando extração de entidades com NER")
-        inputs = self.tokenizer(text, max_length=512, truncation=True, return_tensors="pt")
-        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Converter ids de volta para tokens
+
+        # Pré-processamento do texto
+        inputs = self.tokenizer(
+            text,
+            max_length=512,
+            truncation=True,
+            return_tensors="pt",
+            padding=True
+        )
+
+        # Obtendo tokens
+        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+
+        # Fazendo a predição
         with torch.no_grad():
-            outputs = self.model(**inputs).logits
-            predictions = torch.argmax(outputs, dim=2)
-
+            outputs = self.model(**inputs)
+            predictions = torch.argmax(outputs.logits, dim=2)
+
         entities = []
         for token, prediction in zip(tokens, predictions[0].numpy()):
             entity_label = self.model.config.id2label[prediction]
-            if entity_label != "O": # Ignorar tokens não relacionados a entidades
-                entities.append((token, entity_label))
-
+
+            # Filtrando apenas pessoas e organizações
+            if entity_label in self.target_labels:
+                # Removendo prefixos especiais do tokenizer
+                if token.startswith("##"):
+                    token = token[2:]
+                # Ignorando tokens especiais
+                if token not in ["[CLS]", "[SEP]", "[PAD]"]:
+                    entities.append((token, entity_label))
+
         logger.info(f"Entidades extraídas: {entities}")
         return entities
 
     def extract_representatives(self, entities: List[Tuple[str, str]]) -> List[str]:
+        if not entities:
+            return []
+
         representatives = []
         current_entity = []
         current_label = None
-
+
         for token, label in entities:
-            if label == current_label:
+            # Verificando se é continuação da mesma entidade
+            is_same_entity = (
+                (label.startswith('B-') and current_label and current_label.endswith(label[2:])) or
+                (label.startswith('I-') and current_label and current_label.endswith(label[2:]))
+            )
+
+            if is_same_entity:
                 current_entity.append(token)
             else:
                 if current_entity:
-                    representatives.append(" ".join(current_entity))
+                    representatives.append("".join(current_entity).replace(" ##", ""))
                 current_entity = [token]
                 current_label = label
-
+
+        # Adicionando a última entidade
         if current_entity:
-            representatives.append(" ".join(current_entity))
-
+            representatives.append("".join(current_entity).replace(" ##", ""))
+
+        # Removendo duplicatas e limpando
+        representatives = list(set(representatives))
+        representatives = [rep.strip() for rep in representatives if len(rep.strip()) > 1]
+
         logger.info(f"Representantes extraídos: {representatives}")
         return representatives
-
+
     def analyze(self, text: str) -> List[str]:
         entities = self.extract_entities(text)
         return self.extract_representatives(entities)
 
     def format_output(self, representatives: List[str]) -> str:
         output = "ANÁLISE DO CONTRATO SOCIAL (NER)\n\n"
-        output += "REPRESENTANTES IDENTIFICADOS:\n"
+
+        if not representatives:
+            output += "Nenhum representante ou empresa identificado.\n"
+            return output
+
+        output += "REPRESENTANTES E EMPRESAS IDENTIFICADOS:\n"
         for rep in representatives:
             output += f"- {rep}\n"
-        return output
+
+        return output
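
For context on how the changed methods fit together, here is a minimal driver sketch (not part of the commit). It assumes the repository's analyzers/ directory is an importable package exposing NERAnalyzer as defined above, that the file's top-level imports of torch and the transformers Auto* classes (above this hunk) are unchanged, and the sample contract text is invented for illustration.

# Hypothetical usage sketch, not from the repository: exercises the updated
# NERAnalyzer end to end, assuming the analyzers package is on the import path.
from analyzers.ner_analyzer import NERAnalyzer

if __name__ == "__main__":
    # __init__ downloads jpbahiaz/bert-base-portuguese-ner and its tokenizer on first use.
    analyzer = NERAnalyzer()

    # Illustrative excerpt of a contrato social (invented text).
    texto = (
        "A sociedade será administrada por João da Silva, "
        "sócio da Empresa Exemplo Ltda."
    )

    # analyze() chains extract_entities() and extract_representatives();
    # format_output() renders the text report built in the diff above.
    representantes = analyzer.analyze(texto)
    print(analyzer.format_output(representantes))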