File size: 2,759 Bytes
8166d49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
from typing import List, Tuple
import logging
from .base_analyzer import BaseAnalyzer

# Module-level logger named after this module; used for load/extraction progress messages below.
logger = logging.getLogger(__name__)

class NERAnalyzer(BaseAnalyzer):
    """Identify representatives (people and organizations) in a document via NER.

    Wraps a Hugging Face token-classification model and its tokenizer. Token
    labels are read from the model config and are expected in BIO form
    (e.g. ``B-PESSOA`` / ``I-ORGANIZACAO``); only the entity types listed in
    ``REPRESENTATIVE_TYPES`` are reported as representatives.
    """

    DEFAULT_MODEL_NAME = "dominguesm/ner-legal-bert-base-cased-ptbr"
    # Entity types (the part after "B-"/"I-") that count as representatives.
    REPRESENTATIVE_TYPES = ("PESSOA", "ORGANIZACAO")

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Load the tokenizer and token-classification model.

        Args:
            model_name: Hugging Face model id or local path. Defaults to the
                model this analyzer was originally written for.
        """
        self.model_name = model_name
        logger.info("Carregando o modelo NER: %s", self.model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # Inference only: make sure dropout and similar training layers are off.
        self.model.eval()
        logger.info("Modelo NER e tokenizador carregados com sucesso")

    def extract_entities(
        self, text: str, max_length: int = 512, stride: int = 128
    ) -> List[Tuple[str, str]]:
        """Run the NER model over ``text`` and return its non-"O" tokens.

        Texts longer than ``max_length`` tokens are split into overlapping
        windows (``stride`` tokens of overlap) instead of being truncated, so
        entities near the end of long documents are not lost. Tokens in the
        overlap are reported only once, from the first window that saw them.
        Special tokens ([CLS], [SEP], padding) are never reported.

        Args:
            text: Raw document text.
            max_length: Maximum tokens per model window (including special
                tokens).
            stride: Number of tokens shared between consecutive windows.

        Returns:
            ``(token, label)`` pairs in text order, where ``token`` is the
            tokenizer's sub-word string (may start with ``##``) and ``label``
            is the predicted label name other than ``"O"``.
        """
        logger.debug("Iniciando extração de entidades com NER")
        if not text or not text.strip():
            return []

        encoding = self.tokenizer(
            text,
            max_length=max_length,
            truncation=True,
            stride=stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding=True,  # windows differ in length; needed to stack as tensors
            return_tensors="pt",
        )
        # These are tokenizer bookkeeping, not model inputs.
        offsets = encoding.pop("offset_mapping").tolist()
        encoding.pop("overflow_to_sample_mapping", None)

        id2label = self.model.config.id2label
        device = self.model.device
        entities: List[Tuple[str, str]] = []
        # Character offset up to which text has already been emitted; used to
        # skip tokens repeated in the overlap between consecutive windows.
        consumed_end = 0

        for window in range(encoding["input_ids"].shape[0]):
            # One window at a time keeps memory bounded for long documents.
            window_inputs = {
                key: value[window : window + 1].to(device)
                for key, value in encoding.items()
            }
            with torch.no_grad():
                logits = self.model(**window_inputs).logits
            label_ids = logits.argmax(dim=-1)[0].tolist()

            tokens = encoding.tokens(window)
            sequence_ids = encoding.sequence_ids(window)
            for idx, (token, label_id) in enumerate(zip(tokens, label_ids)):
                if sequence_ids[idx] is None:
                    continue  # [CLS], [SEP] or padding
                start, end = offsets[window][idx]
                if start < consumed_end:
                    continue  # already emitted from the previous window
                consumed_end = end
                label = id2label[label_id]
                if label != "O":
                    entities.append((token, label))

        return entities

    def extract_representatives(self, entities: List[Tuple[str, str]]) -> List[str]:
        """Merge token-level entities into person / organization names.

        A ``B-`` label starts a new entity; an ``I-`` label of the same type
        extends the current one with a space; a ``##`` sub-word piece is glued
        to the current entity without a space. Any other label closes the
        current entity.

        Args:
            entities: ``(token, label)`` pairs as returned by
                ``extract_entities``.

        Returns:
            Entity strings in the order they were closed. Duplicates are kept.
        """
        representatives: List[str] = []
        current_type = ""  # entity type of the open entity ("" when none)
        current_text = ""

        for token, label in entities:
            prefix, _, entity_type = label.partition("-")
            if prefix not in ("B", "I") or entity_type not in self.REPRESENTATIVE_TYPES:
                if current_text:
                    representatives.append(current_text)
                current_type, current_text = "", ""
                continue

            is_subword = token.startswith("##")
            piece = token[2:] if is_subword else token
            continues_current = entity_type == current_type and (is_subword or prefix == "I")
            if continues_current:
                current_text += piece if is_subword else " " + piece
            else:
                if current_text:
                    representatives.append(current_text)
                current_type, current_text = entity_type, piece

        if current_text:
            representatives.append(current_text)
        return representatives

    def analyze(self, text: str) -> List[str]:
        """Return the representative names found in ``text``."""
        entities = self.extract_entities(text)
        return self.extract_representatives(entities)

    def format_output(self, representatives: List[str]) -> str:
        """Render ``representatives`` as a plain-text report, one ``- name`` per line."""
        parts = [
            "ANÁLISE DO CONTRATO SOCIAL (NER)\n\n",
            "REPRESENTANTES IDENTIFICADOS:\n",
        ]
        parts.extend(f"- {rep}\n" for rep in representatives)
        return "".join(parts)