import logging
from typing import Any, Dict, Generator, List, Optional, Tuple, Iterator

from pydantic import PrivateAttr
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    Pipeline,
    pipeline,
)
import spacy
from spacy.language import Language
from spacy.tokens.doc import Doc

from obsei.analyzer.base_analyzer import (
    BaseAnalyzer,
    BaseAnalyzerConfig,
    MAX_LENGTH,
)
from obsei.payload import TextPayload

logger = logging.getLogger(__name__)


class TransformersNERAnalyzer(BaseAnalyzer):
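    """NER analyzer backed by a Hugging Face token-classification pipeline.

    ``model_name_or_path`` may point to any token-classification checkpoint;
    ``tokenizer_name`` falls back to the model path when not given, and
    ``grouped_entities`` controls whether sub-word tokens are merged into
    whole entity spans in the pipeline output.
    """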

    _pipeline: Pipeline = PrivateAttr()
    _max_length: int = PrivateAttr()

    TYPE: str = "NER"
    model_name_or_path: str
    tokenizer_name: Optional[str] = None
    grouped_entities: Optional[bool] = True

    def __init__(self, **data: Any):
        super().__init__(**data)

        model = AutoModelForTokenClassification.from_pretrained(self.model_name_or_path)
        tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name if self.tokenizer_name else self.model_name_or_path,
            use_fast=True,
        )
        self._pipeline = pipeline(
            "ner",
            model=model,
            tokenizer=tokenizer,
            grouped_entities=self.grouped_entities,
            device=self._device_id,
        )
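
        # Derive a truncation limit from the model's positional embedding size
        # when available; inputs are later cut to this many characters, which is
        # only a rough guard, since the model's true limit is counted in tokens.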
        if hasattr(self._pipeline.model.config, "max_position_embeddings"):
            self._max_length = self._pipeline.model.config.max_position_embeddings
        else:
            self._max_length = MAX_LENGTH

    def _prediction_from_model(self, texts: List[str]) -> List[List[Dict[str, Any]]]:
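        """Run the NER pipeline and normalise its output to one entity list per text.

        For a single input the pipeline returns a flat list of entity dicts, so
        the result is wrapped to keep the per-text structure consistent.
        """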
        prediction = self._pipeline(texts)
        return (  # type: ignore[no-any-return]
            prediction
            if len(prediction) and isinstance(prediction[0], list)
            else [prediction]
        )

    def analyze_input(
        self,
        source_response_list: List[TextPayload],
        analyzer_config: Optional[BaseAnalyzerConfig] = None,
        **kwargs: Any,
    ) -> List[TextPayload]:
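        """Attach NER predictions to each payload under ``segmented_data["ner_data"]``."""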
        analyzer_output: List[TextPayload] = []

        for batch_responses in self.batchify(source_response_list, self.batch_size):
            texts = [
                source_response.processed_text[: self._max_length]
                for source_response in batch_responses
            ]
            batch_predictions = self._prediction_from_model(texts)

            for prediction, source_response in zip(batch_predictions, batch_responses):
                segmented_data = {"ner_data": prediction}
                if source_response.segmented_data:
                    segmented_data = {
                        **segmented_data,
                        **source_response.segmented_data,
                    }

                analyzer_output.append(
                    TextPayload(
                        processed_text=source_response.processed_text,
                        meta=source_response.meta,
                        segmented_data=segmented_data,
                        source_name=source_response.source_name,
                    )
                )

        return analyzer_output


class SpacyNERAnalyzer(BaseAnalyzer):
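    """NER analyzer backed by a spaCy pipeline (e.g. an installed ``en_core_web_*`` model).

    Only the entity recognizer is needed, so the tagger, parser, attribute
    ruler, and lemmatizer are disabled at load time. ``tokenizer_name`` and
    ``grouped_entities`` are accepted for parity with ``TransformersNERAnalyzer``
    but are not used by the spaCy pipeline.
    """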

    _nlp: Language = PrivateAttr()

    TYPE: str = "NER"
    model_name_or_path: str
    tokenizer_name: Optional[str] = None
    grouped_entities: Optional[bool] = True
    n_process: int = 1

    def __init__(self, **data: Any):
        super().__init__(**data)

        self._nlp = spacy.load(
            self.model_name_or_path,
            disable=["tagger", "parser", "attribute_ruler", "lemmatizer"],
        )

    def _spacy_pipe_batchify(
        self,
        texts: List[str],
        batch_size: int,
        source_response_list: List[TextPayload],
    ) -> Generator[Tuple[Iterator[Doc], List[TextPayload]], None, None]:
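        """Yield ``(docs, payloads)`` pairs, streaming each slice of texts through ``nlp.pipe``."""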
        for index in range(0, len(texts), batch_size):
            yield (
                self._nlp.pipe(
                    texts=texts[index : index + batch_size],
                    batch_size=batch_size,
                    n_process=self.n_process,
                ),
                source_response_list[index : index + batch_size],
            )

    def analyze_input(
        self,
        source_response_list: List[TextPayload],
        analyzer_config: Optional[BaseAnalyzerConfig] = None,
        **kwargs: Any,
    ) -> List[TextPayload]:
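        """Run spaCy NER over all payloads and attach entities under ``segmented_data["ner_data"]``."""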
        analyzer_output: List[TextPayload] = []

        texts = [
            source_response.processed_text for source_response in source_response_list
        ]
        for batch_docs, batch_source_response in self._spacy_pipe_batchify(
            texts, self.batch_size, source_response_list
        ):
            for doc, source_response in zip(batch_docs, batch_source_response):
                ner_prediction = [
                    {
                        "entity_group": ent.label_,
                        "word": ent.text,
                        "start": ent.start_char,
                        "end": ent.end_char,
                    }
                    for ent in doc.ents
                ]
                segmented_data = {"ner_data": ner_prediction}
                if source_response.segmented_data:
                    segmented_data = {
                        **segmented_data,
                        **source_response.segmented_data,
                    }

                analyzer_output.append(
                    TextPayload(
                        processed_text=source_response.processed_text,
                        meta=source_response.meta,
                        segmented_data=segmented_data,
                        source_name=source_response.source_name,
                    )
                )

        return analyzer_output
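

# Minimal usage sketch (illustrative only). The checkpoint and spaCy pipeline
# names below are assumptions, not requirements of this module; any Hugging Face
# token-classification checkpoint or installed spaCy model should work, and the
# remaining fields are expected to come from BaseAnalyzer/TextPayload defaults.
#
#   analyzer = TransformersNERAnalyzer(model_name_or_path="dslim/bert-base-NER")
#   # analyzer = SpacyNERAnalyzer(model_name_or_path="en_core_web_sm")
#   payloads = analyzer.analyze_input(
#       [TextPayload(processed_text="Apple was founded by Steve Jobs in Cupertino.")]
#   )
#   print(payloads[0].segmented_data["ner_data"])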