import logging
from typing import Any, Dict, List, Optional

from pydantic import Field, PrivateAttr
from transformers import Pipeline, pipeline

from obsei.analyzer.base_analyzer import (
    BaseAnalyzer,
    BaseAnalyzerConfig,
    MAX_LENGTH,
)
from obsei.payload import TextPayload
from obsei.postprocessor.inference_aggregator import InferenceAggregatorConfig
from obsei.postprocessor.inference_aggregator_function import (
    ClassificationAverageScore,
)

logger = logging.getLogger(__name__)
class ClassificationAnalyzerConfig(BaseAnalyzerConfig):
    TYPE: str = "Classification"
    labels: Optional[List[str]] = None
    label_map: Optional[Dict[str, str]] = None
    multi_class_classification: bool = True
    add_positive_negative_labels: bool = True
    aggregator_config: InferenceAggregatorConfig = Field(
        InferenceAggregatorConfig(aggregate_function=ClassificationAverageScore())
    )

    def __init__(self, **data: Any):
        super().__init__(**data)

        # Without candidate labels there is nothing to classify against in
        # zero-shot mode, so disable multi-class scoring and the implicit
        # positive/negative labels.
        if self.labels is None:
            self.multi_class_classification = False
            self.add_positive_negative_labels = False
class TextClassificationAnalyzer(BaseAnalyzer):
    TYPE: str = "Classification"
    pipeline_name: str = "text-classification"
    _pipeline: Pipeline = PrivateAttr()
    _max_length: int = PrivateAttr()
    model_name_or_path: str

    def __init__(self, **data: Any):
        super().__init__(**data)

        self._pipeline = pipeline(
            self.pipeline_name,
            model=self.model_name_or_path,
            device=self._device_id,
        )

        # Cap input length at the model's context window when it is known,
        # otherwise fall back to the library-wide default.
        if hasattr(self._pipeline.model.config, "max_position_embeddings"):
            self._max_length = self._pipeline.model.config.max_position_embeddings
        else:
            self._max_length = MAX_LENGTH
    def prediction_from_model(
        self,
        texts: List[str],
        analyzer_config: Optional[ClassificationAnalyzerConfig] = None,
    ) -> List[Dict[str, Any]]:
        prediction = self._pipeline(texts)
        # The pipeline returns a single dict for a single input; normalize to a list.
        predictions = prediction if isinstance(prediction, list) else [prediction]

        # Optionally rename the model's raw labels (e.g. "LABEL_0") via label_map.
        label_map = analyzer_config.label_map if analyzer_config is not None else {}
        label_map = label_map or {}

        return [
            {label_map.get(prediction["label"], prediction["label"]): prediction["score"]}
            for prediction in predictions
        ]
    def analyze_input(  # type: ignore[override]
        self,
        source_response_list: List[TextPayload],
        analyzer_config: Optional[ClassificationAnalyzerConfig] = None,
        **kwargs: Any,
    ) -> List[TextPayload]:
        analyzer_output: List[TextPayload] = []

        # Optionally split long documents into smaller segments before inference.
        if (
            analyzer_config is not None
            and analyzer_config.use_splitter_and_aggregator
            and analyzer_config.splitter_config
        ):
            source_response_list = self.splitter.preprocess_input(
                source_response_list,
                config=analyzer_config.splitter_config,
            )

        for batch_responses in self.batchify(source_response_list, self.batch_size):
            # Note: truncation here is character-based, not token-based.
            texts = [
                source_response.processed_text[: self._max_length]
                for source_response in batch_responses
            ]

            batch_predictions = self.prediction_from_model(
                texts=texts, analyzer_config=analyzer_config
            )

            for score_dict, source_response in zip(batch_predictions, batch_responses):
                segmented_data = {"classifier_data": score_dict}
                if source_response.segmented_data:
                    segmented_data = {
                        **segmented_data,
                        **source_response.segmented_data,
                    }

                analyzer_output.append(
                    TextPayload(
                        processed_text=source_response.processed_text,
                        meta=source_response.meta,
                        segmented_data=segmented_data,
                        source_name=source_response.source_name,
                    )
                )

        # Aggregate segment-level predictions back into document-level scores.
        if (
            analyzer_config is not None
            and analyzer_config.use_splitter_and_aggregator
            and analyzer_config.aggregator_config
        ):
            analyzer_output = self.aggregator.postprocess_input(
                input_list=analyzer_output,
                config=analyzer_config.aggregator_config,
            )

        return analyzer_output
class ZeroShotClassificationAnalyzer(TextClassificationAnalyzer):
    pipeline_name: str = "zero-shot-classification"

    def prediction_from_model(
        self,
        texts: List[str],
        analyzer_config: Optional[ClassificationAnalyzerConfig] = None,
    ) -> List[Dict[str, Any]]:
        if analyzer_config is None:
            raise ValueError("analyzer_config can't be None")

        # Copy the configured labels so the shared config is never mutated.
        labels = list(analyzer_config.labels or [])
        if analyzer_config.add_positive_negative_labels:
            if "positive" not in labels:
                labels.append("positive")
            if "negative" not in labels:
                labels.append("negative")

        if len(labels) == 0:
            raise ValueError(
                "`labels` can't be empty when `add_positive_negative_labels` is False"
            )

        prediction = self._pipeline(
            texts,
            candidate_labels=labels,
            multi_label=analyzer_config.multi_class_classification,
        )
        predictions = prediction if isinstance(prediction, list) else [prediction]

        # Each prediction carries parallel "labels" and "scores" lists; zip them
        # into a label -> score mapping.
        return [
            dict(zip(prediction["labels"], prediction["scores"]))
            for prediction in predictions
        ]
    def analyze_input(  # type: ignore[override]
        self,
        source_response_list: List[TextPayload],
        analyzer_config: Optional[ClassificationAnalyzerConfig] = None,
        **kwargs: Any,
    ) -> List[TextPayload]:
        # Zero-shot classification needs candidate labels, so a config is mandatory.
        if analyzer_config is None:
            raise ValueError("analyzer_config can't be None")

        return super().analyze_input(
            source_response_list=source_response_list,
            analyzer_config=analyzer_config,
            **kwargs,
        )
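

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API): the
# checkpoint name below is an assumption -- any Hugging Face
# "zero-shot-classification" checkpoint should work the same way, and the
# sample texts/labels are made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    payloads = [
        TextPayload(processed_text="Great product, works exactly as advertised."),
        TextPayload(processed_text="The delivery was late and support never replied."),
    ]

    # Zero-shot classification against custom labels; "positive"/"negative"
    # are appended automatically because add_positive_negative_labels is True.
    zero_shot_analyzer = ZeroShotClassificationAnalyzer(
        model_name_or_path="typeform/mobilebert-uncased-mnli",  # assumed checkpoint
    )
    zero_shot_config = ClassificationAnalyzerConfig(labels=["delivery", "quality"])

    for payload in zero_shot_analyzer.analyze_input(
        source_response_list=payloads, analyzer_config=zero_shot_config
    ):
        print(payload.segmented_data["classifier_data"])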