from abc import abstractmethod
from typing import Any, Generator, List, Optional

from pydantic import Field, PrivateAttr
from pydantic_settings import BaseSettings

from obsei.misc import gpu_util
from obsei.payload import TextPayload
from obsei.postprocessor.inference_aggregator import (
    InferenceAggregator,
    InferenceAggregatorConfig,
)
from obsei.preprocessor.text_splitter import TextSplitter, TextSplitterConfig
from obsei.workflow.base_store import BaseStore

MAX_LENGTH: int = 510
DEFAULT_BATCH_SIZE_GPU: int = 64
DEFAULT_BATCH_SIZE_CPU: int = 4
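
# Note (assumption): MAX_LENGTH of 510 appears to leave room for the two
# special tokens ([CLS]/[SEP]) within the common 512-token transformer input
# limit; the batch-size defaults above are applied per device in
# BaseAnalyzer.__init__.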


class BaseAnalyzerConfig(BaseSettings):
    """Base configuration shared by all analyzers."""

    TYPE: str = "Base"
    use_splitter_and_aggregator: Optional[bool] = False
    splitter_config: Optional[TextSplitterConfig] = None
    aggregator_config: Optional[InferenceAggregatorConfig] = None

    def __init__(self, **data: Any):
        super().__init__(**data)
        # Both sub-configs are required when splitting/aggregation is enabled
        if self.use_splitter_and_aggregator and (
            not self.splitter_config or not self.aggregator_config
        ):
            raise AttributeError(
                "Need splitter_config and aggregator_config if enabling "
                "use_splitter_and_aggregator option"
            )

    class Config:
        arbitrary_types_allowed = True
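

# Usage sketch (illustrative, not executed): enabling the splitter also
# requires an aggregator config. The exact sub-config constructor arguments
# vary; see obsei.preprocessor.text_splitter and
# obsei.postprocessor.inference_aggregator for the concrete options.
#
#   config = BaseAnalyzerConfig(
#       use_splitter_and_aggregator=True,
#       splitter_config=TextSplitterConfig(...),
#       aggregator_config=InferenceAggregatorConfig(...),
#   )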


class BaseAnalyzer(BaseSettings):
    _device_id: int = PrivateAttr()
    TYPE: str = "Base"
    store: Optional[BaseStore] = None
    device: str = "auto"
    """
    auto: choose gpu if present else use cpu
    cpu: use cpu
    cuda:{id}: use cuda device with given id
    """
    batch_size: int = -1
    splitter: TextSplitter = Field(default=TextSplitter())
    aggregator: InferenceAggregator = Field(default=InferenceAggregator())

    def __init__(self, **data: Any):
        super().__init__(**data)
        self._device_id = gpu_util.get_device_id(self.device)
        # Default the batch size based on the resolved device
        # (gpu_util resolves CPU to -1, transformers-style device ids)
        if self.batch_size < 0:
            self.batch_size = (
                DEFAULT_BATCH_SIZE_CPU
                if self._device_id == -1
                else DEFAULT_BATCH_SIZE_GPU
            )
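
    # Device-resolution sketch (illustrative; SomeAnalyzer is a hypothetical
    # concrete subclass, and transformers-style device ids are assumed):
    #   SomeAnalyzer(device="cpu")     -> _device_id == -1, batch_size == 4
    #   SomeAnalyzer(device="cuda:0")  -> _device_id == 0,  batch_size == 64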

    @abstractmethod
    def analyze_input(
        self,
        source_response_list: List[TextPayload],
        analyzer_config: Optional[BaseAnalyzerConfig] = None,
        **kwargs: Any,
    ) -> List[TextPayload]:
        """Analyze the given payloads and return enriched payloads."""
        pass

    @staticmethod
    def batchify(
        payload_list: List[TextPayload],
        batch_size: int,
    ) -> Generator[List[TextPayload], None, None]:
        """Yield successive batch_size-sized slices of payload_list."""
        for index in range(0, len(payload_list), batch_size):
            yield payload_list[index : index + batch_size]
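
    # Example: 10 payloads with batch_size=4 yield batches of 4, 4, and 2
    #   batches = list(BaseAnalyzer.batchify(payloads, batch_size=4))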

    class Config:
        arbitrary_types_allowed = True