"""Factories for LLM generation chains.

Two chain styles are provided:

* ``build_singleturn_chain`` — one prompt, one structured answer.
* ``build_thinking_chain``   — a free-form "think" pass whose output is fed
  back as an assistant turn before a structured "answer" pass.

Both return a langchain ``Runnable`` that produces a ``GenerationAnswer``.
"""

from typing import Any, Callable, Literal, Optional

import openai
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    AIMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
    load_prompt,
)
from langchain_core.runnables import (
    Runnable,
    RunnableLambda,
    RunnablePassthrough,
)
from pydantic import BaseModel

from src.common.paths import PROMPTS_PATH
from src.common.schema import DatasetSchema
from src.generate.llms import LLM_NAME_TO_CLASS, LLMName

# Shared default decoding parameters: greedy / deterministic sampling.
# Copied (never handed out directly) so callers can't mutate the shared dict.
_DEFAULT_LLM_ARGS: dict[str, Any] = {
    "model": "gemma3:4b",
    "top_k": 1,
    "top_p": 1,
    "temperature": 0.0,
}

# Exception types worth retrying at the chain level.
_RETRYABLE_EXCEPTIONS = (openai.PermissionDeniedError,)


class GenerationAnswer(BaseModel):
    """Envelope pairing a chain's structured answer with auxiliary context.

    ``answer`` holds the structured-output model instance produced by the LLM;
    ``context`` carries any intermediate artifacts (e.g. the think-pass text).
    Mutable defaults are safe here: pydantic copies field defaults per instance.
    """

    answer: Any
    context: dict[str, Any] = {}


def build_singleturn_chain(
    answer_class: type[BaseModel],
    llm_class: LLMName = "ollama",
    llm_args: Optional[dict[str, Any]] = None,
    structured_output_method: Literal[
        "function_calling", "json_mode", "json_schema"
    ] = "json_schema",
) -> Runnable:
    """Build a single-turn chain: human prompt -> structured LLM answer.

    Args:
        answer_class: Pydantic model the LLM must emit as structured output.
        llm_args: Constructor kwargs for the LLM; defaults to a copy of
            ``_DEFAULT_LLM_ARGS`` (``None`` sentinel avoids a mutable default).
        structured_output_method: Strategy passed to ``with_structured_output``.

    Returns:
        A ``Runnable`` mapping the input dict to a ``GenerationAnswer``
        (empty ``context``), with retries on permission errors.
    """
    if llm_args is None:
        llm_args = dict(_DEFAULT_LLM_ARGS)

    llm = LLM_NAME_TO_CLASS[llm_class](**llm_args)
    llm = llm.with_structured_output(
        answer_class,
        method=structured_output_method,
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            HumanMessagePromptTemplate(
                prompt=load_prompt(PROMPTS_PATH / "singleturn.yaml")
            )
        ]
    )

    chain = RunnablePassthrough.assign(answer=prompt | llm) | RunnableLambda(
        lambda x: GenerationAnswer(
            answer=x["answer"],
            context={},
        )
    )
    return chain.with_retry(retry_if_exception_type=_RETRYABLE_EXCEPTIONS)


def build_thinking_chain(
    answer_class: type[BaseModel],
    llm_class: LLMName = "ollama",
    think_llm_args: Optional[dict[str, Any]] = None,
    answer_llm_args: Optional[dict[str, Any]] = None,
    structured_output_method: Literal[
        "function_calling", "json_mode", "json_schema"
    ] = "json_schema",
) -> Runnable:
    """Build a two-pass chain: free-form "think", then structured "answer".

    Pass 1 runs a system + task prompt and captures raw text via
    ``StrOutputParser``. Pass 2 replays the same messages, appends the think
    text as an assistant turn plus a closing human prompt, and requests
    structured output of ``answer_class``.

    Args:
        answer_class: Pydantic model for the final structured answer.
        think_llm_args / answer_llm_args: Constructor kwargs for each pass's
            LLM; each defaults to a copy of ``_DEFAULT_LLM_ARGS``.
        structured_output_method: Strategy passed to ``with_structured_output``.

    Returns:
        A ``Runnable`` producing a ``GenerationAnswer`` whose ``context``
        records the intermediate ``think_answer`` text.
    """
    if think_llm_args is None:
        think_llm_args = dict(_DEFAULT_LLM_ARGS)
    if answer_llm_args is None:
        answer_llm_args = dict(_DEFAULT_LLM_ARGS)

    think_llm = LLM_NAME_TO_CLASS[llm_class](**think_llm_args)
    think_prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate(
                prompt=load_prompt(PROMPTS_PATH / "simple_think_system.yaml")
            ),
            # Template variable is the dataset's task-text column name.
            HumanMessagePromptTemplate.from_template(
                f"{{{DatasetSchema.task_text}}}"
            ),
        ]
    )
    think_chain = think_prompt | think_llm | StrOutputParser()

    # Answer prompt = think conversation + think output + closing instruction.
    answer_prompt = ChatPromptTemplate.from_messages(
        think_prompt.messages
        + [
            AIMessagePromptTemplate.from_template("{think_answer}"),
            HumanMessagePromptTemplate(
                prompt=load_prompt(PROMPTS_PATH / "simple_think_end.yaml")
            ),
        ]
    )

    answer_llm = LLM_NAME_TO_CLASS[llm_class](**answer_llm_args)
    answer_llm = answer_llm.with_structured_output(
        answer_class,
        method=structured_output_method,
    )

    chain = (
        RunnablePassthrough.assign(think_answer=think_chain)
        | RunnablePassthrough.assign(answer=answer_prompt | answer_llm)
        | RunnableLambda(
            lambda x: GenerationAnswer(
                answer=x["answer"],
                context={"think_answer": x["think_answer"]},
            )
        )
    )
    return chain.with_retry(retry_if_exception_type=_RETRYABLE_EXCEPTIONS)


GeneratorName = Literal["singleturn", "thinking"]

# Registry mapping a generator name to its chain-factory callable.
GENERATORS_NAME_TO_FACTORY: dict[str, Callable[[type[BaseModel]], Runnable]] = {
    "singleturn": build_singleturn_chain,
    "thinking": build_thinking_chain,
}