d0rj's picture
feat: Initial commit
1719436
raw
history blame
855 Bytes
import pathlib
from typing import Literal, Any, get_args
from pydantic import BaseModel
from pydantic_yaml import parse_yaml_raw_as
from src.generate.llms import LLMName
from src.generate.generators import GeneratorName
class GenerationConfig(BaseModel):
    """Validated settings for a generation run, loadable from YAML.

    Instances can be created directly, from a YAML string with
    ``from_yaml``, or from a YAML file on disk with ``from_file``.
    """

    # Generator selector; defaults to the first type argument of GeneratorName.
    build_function: GeneratorName = get_args(GeneratorName)[0]
    # LLM selector; defaults to the first type argument of LLMName.
    llm_class: LLMName = get_args(LLMName)[0]
    # One of the three accepted method names; "json_schema" unless overridden.
    structured_output_method: Literal[
        "function_calling", "json_mode", "json_schema"
    ] = "json_schema"
    # Free-form extra options keyed by string.
    # NOTE(review): presumably forwarded to the generator/LLM — confirm
    # against the code that consumes this config.
    kwargs: dict[str, Any] = {}

    @classmethod
    def from_yaml(cls, yaml_str: str) -> "GenerationConfig":
        """Build a config from YAML text.

        Args:
            yaml_str: YAML document describing the config fields.

        Returns:
            The result of ``parse_yaml_raw_as(cls, yaml_str)``.
        """
        return parse_yaml_raw_as(cls, yaml_str)

    @classmethod
    def from_file(cls, file_path: str | pathlib.Path) -> "GenerationConfig":
        """Build a config from a YAML file.

        Args:
            file_path: Location of the YAML file to read.

        Returns:
            The config produced by ``from_yaml`` for the file's contents.

        Raises:
            OSError: If the file cannot be opened or read.
            UnicodeDecodeError: If the file is not valid UTF-8.
        """
        # Decode explicitly as UTF-8 so parsing does not depend on the
        # platform's locale encoding (which is what a bare open() uses).
        with open(file_path, "r", encoding="utf-8") as file:
            yaml_str = file.read()
        return cls.from_yaml(yaml_str)