Spaces:
Sleeping
Sleeping
from abc import abstractmethod | |
from typing import Any, Generator, List, Optional | |
from pydantic import Field, PrivateAttr | |
from pydantic_settings import BaseSettings | |
from obsei.misc import gpu_util | |
from obsei.payload import TextPayload | |
from obsei.postprocessor.inference_aggregator import ( | |
InferenceAggregator, | |
InferenceAggregatorConfig, | |
) | |
from obsei.preprocessor.text_splitter import TextSplitter, TextSplitterConfig | |
from obsei.workflow.base_store import BaseStore | |
# Fallback batch sizes, picked by BaseAnalyzer.__init__ when the caller leaves
# ``batch_size`` negative.
DEFAULT_BATCH_SIZE_CPU: int = 4
DEFAULT_BATCH_SIZE_GPU: int = 64

# Not referenced in this part of the file.
# NOTE(review): presumably the maximum number of tokens per input for
# transformer models (512 minus two special tokens) -- confirm with callers.
MAX_LENGTH: int = 510
class BaseAnalyzerConfig(BaseSettings):
    """Base settings object passed to analyzers.

    Fields:
        TYPE: Identifier string for this config kind (``"Base"`` here).
        use_splitter_and_aggregator: Opt-in flag; when truthy, both
            ``splitter_config`` and ``aggregator_config`` must be supplied.
        splitter_config: Optional ``TextSplitterConfig``.
        aggregator_config: Optional ``InferenceAggregatorConfig``.

    Raises:
        AttributeError: If ``use_splitter_and_aggregator`` is enabled while
            ``splitter_config`` or ``aggregator_config`` is missing.
    """

    TYPE: str = "Base"
    use_splitter_and_aggregator: Optional[bool] = False
    splitter_config: Optional[TextSplitterConfig] = None
    aggregator_config: Optional[InferenceAggregatorConfig] = None

    def __init__(self, **data: Any):
        super().__init__(**data)

        if not self.use_splitter_and_aggregator:
            return

        # The error message requires BOTH configs, so reject the object when
        # either one is missing (previously it was only rejected when both
        # were missing, letting half-configured objects through).
        if not (self.splitter_config and self.aggregator_config):
            raise AttributeError(
                "Need splitter_config and aggregator_config if enabling use_splitter_and_aggregator "
                "option"
            )

    class Config:
        # Allow field types that pydantic has no built-in validator for.
        arbitrary_types_allowed = True
class BaseAnalyzer(BaseSettings):
    """Abstract base for analyzers that turn ``TextPayload`` lists into results.

    Subclasses must implement :meth:`analyze_input`.

    Fields:
        TYPE: Identifier string for this analyzer kind (``"Base"`` here).
        store: Optional ``BaseStore`` instance.
        device: Device selector, one of:
            auto: choose gpu if present else use cpu
            cpu: use cpu
            cuda:{id} - cuda device id
        batch_size: Number of payloads per batch. A negative value (the
            default ``-1``) is replaced in ``__init__`` by one of the module
            level default batch sizes.
        splitter: ``TextSplitter`` instance (default-constructed if omitted).
        aggregator: ``InferenceAggregator`` instance (default-constructed if
            omitted).
    """

    # Resolved in __init__ from ``device`` via gpu_util.get_device_id.
    _device_id: int = PrivateAttr()
    TYPE: str = "Base"
    store: Optional[BaseStore] = None
    device: str = "auto"
    batch_size: int = -1
    splitter: TextSplitter = Field(default=TextSplitter())
    aggregator: InferenceAggregator = Field(default=InferenceAggregator())

    def __init__(self, **data: Any):
        """Validate the fields, resolve the device id and default the batch size."""
        super().__init__(**data)

        self._device_id = gpu_util.get_device_id(self.device)

        if self.batch_size < 0:
            # NOTE(review): device id 0 selects the CPU batch size here. Many
            # libraries use -1 for CPU and 0 for the first GPU -- confirm
            # against gpu_util.get_device_id that this mapping is intended.
            self.batch_size = (
                DEFAULT_BATCH_SIZE_CPU
                if self._device_id == 0
                else DEFAULT_BATCH_SIZE_GPU
            )

    @abstractmethod
    def analyze_input(
        self,
        source_response_list: List[TextPayload],
        analyzer_config: Optional[BaseAnalyzerConfig] = None,
        **kwargs: Any,
    ) -> List[TextPayload]:
        """Analyze the given payloads and return the resulting payloads.

        Args:
            source_response_list: Payloads to analyze.
            analyzer_config: Optional per-call configuration.
            **kwargs: Extra, implementation-specific options.

        Returns:
            The list of payloads produced by the concrete analyzer. This base
            implementation has no body and must be overridden.
        """

    @staticmethod
    def batchify(
        payload_list: List[TextPayload],
        batch_size: int,
    ) -> Generator[List[TextPayload], None, None]:
        """Yield consecutive slices of ``payload_list`` of at most ``batch_size`` items.

        Declared as a static method so it can be called both on the class and
        on an instance (the signature takes no ``self``).

        Args:
            payload_list: Payloads to split into batches.
            batch_size: Maximum number of payloads per batch; must be >= 1.

        Yields:
            Successive sub-lists of ``payload_list``; the last one may be
            shorter than ``batch_size``.

        Raises:
            ValueError: On first iteration, if ``batch_size`` is not positive
                (a negative step would otherwise silently yield nothing).
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        for start in range(0, len(payload_list), batch_size):
            yield payload_list[start : start + batch_size]

    class Config:
        # Allow field types that pydantic has no built-in validator for.
        arbitrary_types_allowed = True