# app.py
import gradio as gr
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_community.document_transformers import LongContextReorder
from config import LLM_MODEL, STREAMING
from embeddings import get_embeddings
from retrievers import load_retrievers
from llm import get_llm
from prompt import get_prompt
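
# Local helper modules (assumed project layout): config.py holds settings,
# while embeddings.py, retrievers.py, llm.py, and prompt.py each construct
# one stage of the RAG pipeline.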


def create_rag_chain(chat_history):
    embeddings = get_embeddings()
    retriever = load_retrievers(embeddings)
    llm = get_llm(streaming=STREAMING)
    prompt = get_prompt(chat_history)

    return (
        {
            # Reorder retrieved documents so the most relevant ones land at
            # the edges of the context window, mitigating "lost in the middle".
            "context": retriever
            | RunnableLambda(LongContextReorder().transform_documents),
            "question": RunnablePassthrough(),
        }
        | prompt
        # Pick the concrete model at runtime; llm.py presumably registers
        # the alternatives under the configurable key "llm".
        | llm.with_config(configurable={"llm": LLM_MODEL})
        | StrOutputParser()
    )
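
# Usage sketch (hypothetical; the app only calls this through the handlers
# below): the chain maps a question string to an answer string, since
# StrOutputParser() terminates the pipeline.
#
#   chain = create_rag_chain(chat_history=[])
#   chain.invoke("What constitutes wrongful dismissal?")  # -> str answer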


def respond_stream(message, history):
    rag_chain = create_rag_chain(history)
    # gr.ChatInterface expects the generator to yield the accumulated
    # response so far, not the individual deltas.
    response = ""
    for chunk in rag_chain.stream(message):
        response += chunk
        yield response


def respond(message, history):
    # Non-streaming variant: return the complete answer in a single call.
    rag_chain = create_rag_chain(history)
    return rag_chain.invoke(message)
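
# Note: with the default ChatInterface settings, `history` arrives as a list
# of [user_message, assistant_message] pairs; get_prompt() presumably folds
# these into the prompt template (an assumption about prompt.py).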


demo = gr.ChatInterface(
    respond_stream if STREAMING else respond,
    title="Ask me about court precedents!",
    description="Hello!\nI am an AI QA bot for court precedents. I have deep knowledge of case law. Whenever you need help with precedents, feel free to ask!",
)

if __name__ == "__main__":
    demo.launch()
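    # On some Gradio versions, generator (streaming) handlers need the queue
    # enabled first; `demo.queue().launch()` is the usual fix if output stalls.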