from abc import abstractmethod
from typing import Any, Generator, List, Optional

from pydantic import Field, PrivateAttr
from pydantic_settings import BaseSettings

from obsei.misc import gpu_util
from obsei.payload import TextPayload
from obsei.postprocessor.inference_aggregator import (
    InferenceAggregator,
    InferenceAggregatorConfig,
)
from obsei.preprocessor.text_splitter import TextSplitter, TextSplitterConfig
from obsei.workflow.base_store import BaseStore

# Input-length cap. Not referenced in this part of the file; presumably a
# model token limit (510 = 512 minus two special tokens?) -- TODO confirm units.
MAX_LENGTH: int = 510

# Fallback batch sizes chosen by BaseAnalyzer when no usable batch_size is
# configured, selected according to the resolved inference device.
DEFAULT_BATCH_SIZE_CPU: int = 4
DEFAULT_BATCH_SIZE_GPU: int = 64


class BaseAnalyzerConfig(BaseSettings):
    """Optional per-call configuration accepted by ``BaseAnalyzer.analyze_input``.

    Fields:

    * ``TYPE`` - identifier string for the config kind.
    * ``use_splitter_and_aggregator`` - enables the split/aggregate options;
      requires both ``splitter_config`` and ``aggregator_config``.
    * ``splitter_config`` - ``TextSplitterConfig`` for the splitting step.
    * ``aggregator_config`` - ``InferenceAggregatorConfig`` for the
      aggregation step.

    Raises:
        AttributeError: if ``use_splitter_and_aggregator`` is enabled while
            either ``splitter_config`` or ``aggregator_config`` is missing.
    """

    TYPE: str = "Base"
    use_splitter_and_aggregator: Optional[bool] = False
    splitter_config: Optional[TextSplitterConfig] = None
    aggregator_config: Optional[InferenceAggregatorConfig] = None

    def __init__(self, **data: Any):
        super().__init__(**data)

        # Both configs are needed together, so a missing *either* one is an
        # error -- not only the case where both are absent.
        if self.use_splitter_and_aggregator and (
            self.splitter_config is None or self.aggregator_config is None
        ):
            raise AttributeError("Need splitter_config and aggregator_config if enabling use_splitter_and_aggregator "
                                 "option")

    class Config:
        arbitrary_types_allowed = True


class BaseAnalyzer(BaseSettings):
    """Abstract base for analyzers that process lists of ``TextPayload``.

    Fields:

    * ``TYPE`` - identifier string for the analyzer kind.
    * ``store`` - optional ``BaseStore``.
    * ``device`` - where inference runs:
        - ``auto``: choose gpu if present else use cpu
        - ``cpu``: use cpu
        - ``cuda:{id}``: cuda device id
    * ``batch_size`` - items per batch; a non-positive value (default ``-1``)
      selects ``DEFAULT_BATCH_SIZE_CPU`` or ``DEFAULT_BATCH_SIZE_GPU`` based on
      the resolved device.
    * ``splitter`` / ``aggregator`` - ``TextSplitter`` and
      ``InferenceAggregator`` instances available to subclasses.

    Subclasses must implement :meth:`analyze_input`.
    """

    _device_id: int = PrivateAttr()
    TYPE: str = "Base"
    store: Optional[BaseStore] = None
    device: str = "auto"
    batch_size: int = -1
    # default_factory builds the helpers per analyzer at construction time
    # rather than creating a single instance when this module is imported.
    splitter: TextSplitter = Field(default_factory=TextSplitter)
    aggregator: InferenceAggregator = Field(default_factory=InferenceAggregator)

    def __init__(self, **data: Any):
        super().__init__(**data)

        self._device_id = gpu_util.get_device_id(self.device)
        # A batch size of 0 would make batchify() fail, so treat it like
        # "unset" as well and pick a device-appropriate default.
        if self.batch_size <= 0:
            # NOTE(review): assumes gpu_util.get_device_id follows the Hugging
            # Face pipeline convention (negative id = CPU, id >= 0 = CUDA
            # ordinal), consistent with the ``cuda:{id}`` option. Testing
            # ``== 0`` would classify cuda:0 as CPU. TODO confirm in gpu_util.
            on_cpu = self._device_id < 0
            self.batch_size = (
                DEFAULT_BATCH_SIZE_CPU if on_cpu else DEFAULT_BATCH_SIZE_GPU
            )

    @abstractmethod
    def analyze_input(
        self,
        source_response_list: List[TextPayload],
        analyzer_config: Optional[BaseAnalyzerConfig] = None,
        **kwargs: Any,
    ) -> List[TextPayload]:
        """Analyze ``source_response_list`` and return the resulting payloads.

        Args:
            source_response_list: payloads to analyze.
            analyzer_config: optional per-call configuration.
            **kwargs: implementation-specific options.

        Returns:
            List of ``TextPayload`` produced by the concrete analyzer.
        """
        pass

    @staticmethod
    def batchify(
        payload_list: List[TextPayload],
        batch_size: int,
    ) -> Generator[List[TextPayload], None, None]:
        """Yield consecutive slices of ``payload_list`` of at most ``batch_size``.

        The last batch may be shorter; an empty list yields nothing.

        Raises:
            ValueError: if ``batch_size`` is not positive (a negative size
                would otherwise silently yield no batches at all).
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        for start in range(0, len(payload_list), batch_size):
            yield payload_list[start : start + batch_size]

    class Config:
        arbitrary_types_allowed = True