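"""Factories for answer-generation chains.

`build_singleturn_chain` answers a task in a single structured-output LLM
call; `build_thinking_chain` first elicits free-form reasoning and then asks
a second call for the structured answer. Both produce `GenerationAnswer`
objects and retry on `openai.PermissionDeniedError`.
"""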
from typing import Any, Callable, Literal

import openai
from pydantic import BaseModel

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    AIMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
    load_prompt,
)
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough

from src.common.paths import PROMPTS_PATH
from src.common.schema import DatasetSchema
from src.generate.llms import LLM_NAME_TO_CLASS, LLMName


class GenerationAnswer(BaseModel):
    """A generated answer plus any auxiliary context produced along the way."""

    answer: Any
    context: dict[str, Any] = {}


def build_singleturn_chain(
    answer_class: type[BaseModel],
    llm_class: LLMName = "ollama",
    llm_args: dict[str, Any] | None = None,
    structured_output_method: Literal[
        "function_calling", "json_mode", "json_schema"
    ] = "json_schema",
) -> Runnable:
    """Build a chain that answers a task in a single structured-output LLM call."""
    # Avoid a mutable default argument; fall back to a deterministic local model.
    if llm_args is None:
        llm_args = {
            "model": "gemma3:4b",
            "top_k": 1,
            "top_p": 1,
            "temperature": 0.0,
        }
    llm = LLM_NAME_TO_CLASS[llm_class](**llm_args)
    # Constrain the model to return instances of `answer_class`.
    llm = llm.with_structured_output(
        answer_class,
        method=structured_output_method,
    )
    # The single-turn prompt is a lone human message loaded from a YAML template.
    prompt = ChatPromptTemplate.from_messages(
        [
            HumanMessagePromptTemplate(
                prompt=load_prompt(PROMPTS_PATH / "singleturn.yaml")
            )
        ]
    )
    # Run prompt -> LLM, then wrap the structured result in a GenerationAnswer.
    chain = RunnablePassthrough.assign(answer=prompt | llm) | RunnableLambda(
        lambda x: GenerationAnswer(
            answer=x["answer"],
            context={},
        )
    )
    # Retry when the backend rejects a request with a permission error.
    chain = chain.with_retry(
        retry_if_exception_type=(openai.PermissionDeniedError,)
    )
    return chain
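

# A minimal usage sketch (assumptions, not part of this module: a local Ollama
# server is running, singleturn.yaml exposes the variable named by
# DatasetSchema.task_text, assumed here to be "task_text", and the `Answer`
# schema is hypothetical):
#
#     class Answer(BaseModel):
#         text: str
#
#     chain = build_singleturn_chain(Answer)
#     result = chain.invoke({"task_text": "What is 2 + 2?"})
#     print(result.answer.text)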


def build_thinking_chain(
    answer_class: type[BaseModel],
    llm_class: LLMName = "ollama",
    think_llm_args: dict[str, Any] | None = None,
    answer_llm_args: dict[str, Any] | None = None,
    structured_output_method: Literal[
        "function_calling", "json_mode", "json_schema"
    ] = "json_schema",
) -> Runnable:
    """Build a two-stage chain: free-form reasoning first, then a structured answer."""
    # Avoid mutable default arguments; both stages default to the same
    # deterministic local model.
    default_llm_args: dict[str, Any] = {
        "model": "gemma3:4b",
        "top_k": 1,
        "top_p": 1,
        "temperature": 0.0,
    }
    if think_llm_args is None:
        think_llm_args = default_llm_args
    if answer_llm_args is None:
        answer_llm_args = default_llm_args
    # Stage 1: free-form "thinking" over the task text.
    think_llm = LLM_NAME_TO_CLASS[llm_class](**think_llm_args)
    think_prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate(
                prompt=load_prompt(PROMPTS_PATH / "simple_think_system.yaml")
            ),
            # The human message is the raw task text; the template variable
            # name comes from the dataset schema.
            HumanMessagePromptTemplate.from_template(f"{{{DatasetSchema.task_text}}}"),
        ]
    )
    think_chain = think_prompt | think_llm | StrOutputParser()

    # Stage 2: replay the same conversation with the reasoning appended as an
    # AI message, then ask for the final structured answer.
    answer_prompt = ChatPromptTemplate.from_messages(
        think_prompt.messages
        + [
            AIMessagePromptTemplate.from_template("{think_answer}"),
            HumanMessagePromptTemplate(
                prompt=load_prompt(PROMPTS_PATH / "simple_think_end.yaml")
            ),
        ]
    )
    answer_llm = LLM_NAME_TO_CLASS[llm_class](**answer_llm_args)
    answer_llm = answer_llm.with_structured_output(
        answer_class,
        method=structured_output_method,
    )

    # Thread the reasoning into the input, generate the structured answer, and
    # surface the intermediate reasoning through the context field.
    chain = (
        RunnablePassthrough.assign(
            think_answer=think_chain,
        )
        | RunnablePassthrough.assign(answer=answer_prompt | answer_llm)
        | RunnableLambda(
            lambda x: GenerationAnswer(
                answer=x["answer"],
                context={
                    "think_answer": x["think_answer"],
                },
            )
        )
    )
    # Retry when the backend rejects a request with a permission error.
    chain = chain.with_retry(
        retry_if_exception_type=(openai.PermissionDeniedError,)
    )
    return chain
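

# A minimal usage sketch for the two-stage chain (same assumptions as above;
# `Answer` is hypothetical). The intermediate reasoning is kept on the result:
#
#     chain = build_thinking_chain(Answer)
#     result = chain.invoke({"task_text": "What is 2 + 2?"})
#     print(result.context["think_answer"])
#     print(result.answer.text)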


GeneratorName = Literal["singleturn", "thinking"]

GENERATORS_NAME_TO_FACTORY: dict[
    GeneratorName, Callable[[type[BaseModel]], Runnable]
] = {
    "singleturn": build_singleturn_chain,
    "thinking": build_thinking_chain,
}
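
# Looking up a factory by name (sketch; `Answer` as in the examples above):
#
#     factory = GENERATORS_NAME_TO_FACTORY["thinking"]
#     chain = factory(Answer)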