# csc525_retrieval_based_chatbot / chatbot_config.py
# Last commit: 64e7c31 by JoeArmani ("sentence transformer")
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Dict


@dataclass
class ChatbotConfig:
    """RetrievalChatbot Config.

    Holds model identifiers and numeric hyper-parameters for the retrieval
    chatbot. Instances can be serialized with ``to_dict`` and rebuilt with
    ``from_dict``.
    """

    max_context_token_limit: int = 512
    embedding_dim: int = 384  # Match Sentence Transformer dimension
    learning_rate: float = 0.0005
    min_text_length: int = 3
    max_context_turns: int = 20
    pretrained_model: str = 'sentence-transformers/all-MiniLM-L6-v2'
    cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
    summarizer_model: str = 't5-small'
    embedding_batch_size: int = 64
    search_batch_size: int = 64
    max_batch_size: int = 64
    max_retries: int = 3

    def to_dict(self) -> Dict:
        """Convert config to dictionary.

        Only declared dataclass fields are included, in declaration order.
        Extra attributes set on the instance at runtime are ignored, so the
        result round-trips through ``from_dict``. ``Path`` values are
        converted to ``str``.

        Returns:
            A new dict mapping field name to (possibly stringified) value.
        """
        result = {}
        # fields() is the public API and, unlike self.__dict__, neither picks
        # up ad-hoc instance attributes nor breaks under slots=True.
        for field_ in fields(self):
            value = getattr(self, field_.name)
            result[field_.name] = str(value) if isinstance(value, Path) else value
        return result

    @classmethod
    def from_dict(cls, config_dict: Dict) -> 'ChatbotConfig':
        """Create config from dictionary.

        Keys that are not constructor fields of this class are silently
        ignored; fields missing from ``config_dict`` keep their defaults.

        Args:
            config_dict: Mapping of field name to value, e.g. the output of
                ``to_dict``.

        Returns:
            A new ``ChatbotConfig`` instance.
        """
        # Only init=True fields may be passed to the constructor; fields()
        # already excludes ClassVar/InitVar pseudo-fields, which the private
        # __dataclass_fields__ map would include.
        accepted = {field_.name for field_ in fields(cls) if field_.init}
        return cls(**{k: v for k, v in config_dict.items() if k in accepted})