"""Gradio chat app that answers questions with a LangChain RAG chain."""

import gradio as gr
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_community.document_transformers import LongContextReorder

from config import LLM_MODEL, STREAMING
from embeddings import get_embeddings
from retrievers import load_retrievers
from llm import get_llm
from prompt import get_prompt


def create_rag_chain(chat_history):
    """Build the RAG chain: retrieve, reorder, prompt, generate, parse."""
    embeddings = get_embeddings()
    retriever = load_retrievers(embeddings)
    llm = get_llm(streaming=STREAMING)
    prompt = get_prompt(chat_history)
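
    # The dict below is coerced to a RunnableParallel: the retriever branch
    # supplies "context" while RunnablePassthrough forwards the user's
    # question unchanged, so the prompt template receives both keys.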
    return (
        {
            # LongContextReorder places the most relevant documents at the
            # start and end of the context, mitigating the "lost in the
            # middle" effect of long prompts.
            "context": retriever
            | RunnableLambda(LongContextReorder().transform_documents),
            "question": RunnablePassthrough(),
        }
        | prompt
        # Assumes get_llm() exposes an "llm" configurable field (e.g. via
        # configurable_alternatives); LLM_MODEL selects the concrete model.
        | llm.with_config(configurable={"llm": LLM_MODEL})
        | StrOutputParser()
    )


def respond_stream(message, history):
    """Stream the answer, yielding the accumulated text for live rendering."""
    rag_chain = create_rag_chain(history)
    response = ""
    for chunk in rag_chain.stream(message):
        response += chunk
        yield response


def respond(message, history):
    """Return the complete answer in a single non-streaming call."""
    rag_chain = create_rag_chain(history)
    return rag_chain.invoke(message)
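

# gr.ChatInterface streams output when given a generator function and shows a
# single reply when given a regular function, so STREAMING picks the handler.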
demo = gr.ChatInterface(
    respond_stream if STREAMING else respond,
    title="Ask me about Tarot!",
    description=(
        "Hello!\nI'm an AI QA bot for Tarot. I have deep knowledge about "
        "Tarot. Whenever you need help with Tarot, feel free to ask!"
    ),
)


if __name__ == "__main__":
    demo.launch()