from typing import Any, Callable, Literal

import openai
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    AIMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
    load_prompt,
)
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from pydantic import BaseModel

from src.common.paths import PROMPTS_PATH
from src.common.schema import DatasetSchema
from src.generate.llms import LLM_NAME_TO_CLASS, LLMName


class GenerationAnswer(BaseModel):
    """Result of a generation chain: the parsed structured answer plus any
    intermediate context (e.g. the reasoning trace from a thinking step)."""

    answer: Any
    context: dict[str, Any] = {}


def build_singleturn_chain(
    answer_class: type[BaseModel],
    llm_class: LLMName = "ollama",
    llm_args: dict[str, Any] = {
        "model": "gemma3:4b",
        "top_k": 1,
        "top_p": 1,
        "temperature": 0.0,
    },
    structured_output_method: Literal[
        "function_calling", "json_mode", "json_schema"
    ] = "json_schema",
) -> Runnable:
    """Build a single-turn chain: one prompt, one structured LLM call.

    The LLM is constrained to emit `answer_class` via the chosen structured-output
    method; the result is wrapped in a `GenerationAnswer` with empty context.
    """
    llm = LLM_NAME_TO_CLASS[llm_class](**llm_args)
    llm = llm.with_structured_output(
        answer_class,
        method=structured_output_method,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            HumanMessagePromptTemplate(
                prompt=load_prompt(PROMPTS_PATH / "singleturn.yaml")
            )
        ]
    )
    chain = RunnablePassthrough.assign(answer=prompt | llm) | RunnableLambda(
        lambda x: GenerationAnswer(
            answer=x["answer"],
            context={},
        )
    )
    # Retry if the provider rejects the request with a permission error.
    chain = chain.with_retry(
        retry_if_exception_type=(openai.PermissionDeniedError,)
    )
    return chain
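
# Minimal usage sketch (commented out so it does not run on import). It assumes
# an Ollama server is reachable locally and that singleturn.yaml declares the
# input variables passed at invoke time; `Tagging` is a hypothetical schema.
#
#     class Tagging(BaseModel):
#         label: str
#
#     chain = build_singleturn_chain(Tagging)
#     result = chain.invoke({"task_text": "Classify this sentence."})
#     print(result.answer)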


def build_thinking_chain(
    answer_class: type[BaseModel],
    llm_class: LLMName = "ollama",
    think_llm_args: dict[str, Any] = {
        "model": "gemma3:4b",
        "top_k": 1,
        "top_p": 1,
        "temperature": 0.0,
    },
    answer_llm_args: dict[str, Any] = {
        "model": "gemma3:4b",
        "top_k": 1,
        "top_p": 1,
        "temperature": 0.0,
    },
    structured_output_method: Literal[
        "function_calling", "json_mode", "json_schema"
    ] = "json_schema",
) -> Runnable:
    """Build a two-step chain: a free-form "thinking" pass, then a structured answer.

    The first LLM call produces a plain-text reasoning trace; the second call sees
    the same conversation plus that trace as an AI message and is constrained to
    emit `answer_class`. The trace is preserved in `GenerationAnswer.context`.
    """
    # Step 1: free-form reasoning over the task text.
    think_llm = LLM_NAME_TO_CLASS[llm_class](**think_llm_args)
    think_prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate(
                prompt=load_prompt(PROMPTS_PATH / "simple_think_system.yaml")
            ),
            # Placeholder keyed by the dataset's task-text field name.
            HumanMessagePromptTemplate.from_template(f"{{{DatasetSchema.task_text}}}"),
        ]
    )
    think_chain = think_prompt | think_llm | StrOutputParser()

    # Step 2: replay the conversation with the reasoning trace appended as an AI
    # message, then ask for the final structured answer.
    answer_prompt = ChatPromptTemplate.from_messages(
        think_prompt.messages
        + [
            AIMessagePromptTemplate.from_template("{think_answer}"),
            HumanMessagePromptTemplate(
                prompt=load_prompt(PROMPTS_PATH / "simple_think_end.yaml")
            ),
        ]
    )
    answer_llm = LLM_NAME_TO_CLASS[llm_class](**answer_llm_args)
    answer_llm = answer_llm.with_structured_output(
        answer_class,
        method=structured_output_method,
    )
    chain = (
        RunnablePassthrough.assign(
            think_answer=think_chain,
        )
        | RunnablePassthrough.assign(answer=answer_prompt | answer_llm)
        | RunnableLambda(
            lambda x: GenerationAnswer(
                answer=x["answer"],
                context={
                    "think_answer": x["think_answer"],
                },
            )
        )
    )
    # Retry if the provider rejects the request with a permission error.
    chain = chain.with_retry(
        retry_if_exception_type=(openai.PermissionDeniedError,)
    )
    return chain


GeneratorName = Literal["singleturn", "thinking"]
GENERATORS_NAME_TO_FACTORY: dict[GeneratorName, Callable[[type[BaseModel]], Runnable]] = {
    "singleturn": build_singleturn_chain,
    "thinking": build_thinking_chain,
}
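

if __name__ == "__main__":
    # Hedged smoke test, not part of the module's API. It assumes an Ollama
    # server is running locally with gemma3:4b pulled, and that the YAML prompts
    # read the task from the `DatasetSchema.task_text` field. `DemoAnswer` is a
    # hypothetical schema defined only for this demo.
    class DemoAnswer(BaseModel):
        answer: str

    demo_chain = GENERATORS_NAME_TO_FACTORY["thinking"](DemoAnswer)
    result = demo_chain.invoke({DatasetSchema.task_text: "What is 2 + 2?"})
    print(result.answer)
    print(result.context.get("think_answer"))  # intermediate reasoning trace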