import os
import gradio as gr
from kiwipiepy import Kiwi
from typing import Generator, List, Tuple

from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_community.document_transformers import LongContextReorder

from libs.config import STREAMING
from libs.embeddings import get_embeddings
from libs.retrievers import load_retrievers
from libs.llm import get_llm
from libs.prompt import get_prompt

# Silence the HuggingFace tokenizers fork-parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Instantiate the Korean morphological analyzer once; constructing a Kiwi
# object on every call is expensive.
kiwi = Kiwi()


def kiwi_tokenize(text: str) -> List[str]:
    """Tokenize Korean text into morpheme surface forms (e.g., for a BM25 retriever)."""
    return [token.form for token in kiwi.tokenize(text)]


# Build the embedding model and document retriever once at startup; they are
# shared across all chat sessions.
embeddings = get_embeddings()
retriever = load_retrievers(embeddings)


# Available models (key: model identifier, value: label shown to the user)
AVAILABLE_MODELS = {
    "gpt_3_5_turbo": "GPT-3.5 Turbo",
    "gpt_4o": "GPT-4o",
    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
    "gemini_1_5_flash": "Gemini 1.5 Flash",
    "llama3_70b": "Llama3 70b",
}
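# NOTE: each key is assumed to correspond to a configurable LLM alternative
# registered in libs.llm.get_llm (selected via with_config in create_rag_chain).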


def create_rag_chain(chat_history: List[Tuple[str, str]], model: str):
    # Select the requested model at runtime via LangChain's configurable fields.
    llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
    prompt = get_prompt(chat_history)

    # LCEL pipeline: retrieve documents, reorder them so the most relevant ones
    # sit at the start and end of the context (mitigating "lost in the middle"),
    # then fill the prompt, call the LLM, and parse the output to a string.
    return (
        {
            "context": retriever
            | RunnableLambda(LongContextReorder().transform_documents),
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )
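
# A minimal usage sketch (hypothetical values; assumes retriever data and API
# keys are configured):
#   chain = create_rag_chain(chat_history=[], model="gpt_4o")
#   answer = chain.invoke("<question about a court ruling>")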


def respond_stream(
    message: str, history: List[Tuple[str, str]], model: str
) -> Generator[str, None, None]:
    rag_chain = create_rag_chain(history, model)
    for chunk in rag_chain.stream(message):
        yield chunk


def respond(message: str, history: List[Tuple[str, str]], model: str) -> str:
    rag_chain = create_rag_chain(history, model)
    return rag_chain.invoke(message)


def get_model_key(label: str) -> str:
    """Map a user-facing model label back to its key in AVAILABLE_MODELS."""
    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)


def chat_function(
    message: str, history: List[Tuple[str, str]], model_label: str
) -> Generator[str, None, None]:
    model_key = get_model_key(model_label)
    if STREAMING:
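        # Gradio's ChatInterface replaces the displayed message with each yield,
        # so accumulate chunks and yield the full response so far.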
        response = ""
        for chunk in respond_stream(message, history, model_key):
            response += chunk
            yield response
    else:
        response = respond(message, history, model_key)
        yield response


with gr.Blocks() as demo:
    gr.Markdown("# λŒ€λ²•μ› νŒλ‘€ 상담 λ„μš°λ―Έ")
    gr.Markdown(
        "μ•ˆλ…•ν•˜μ„Έμš”! λŒ€λ²•μ› νŒλ‘€μ— κ΄€ν•œ μ§ˆλ¬Έμ— λ‹΅λ³€ν•΄λ“œλ¦¬λŠ” AI 상담 λ„μš°λ―Έμž…λ‹ˆλ‹€. νŒλ‘€ 검색, 해석, 적용 등에 λŒ€ν•΄ κΆκΈˆν•˜μ‹  점이 있으면 μ–Έμ œλ“  λ¬Όμ–΄λ³΄μ„Έμš”."
    )

    model_dropdown = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.values()),
        label="λͺ¨λΈ 선택",
        value=AVAILABLE_MODELS["gpt_4o"],  # default to GPT-4o
    )

    # additional_inputs forwards the dropdown's current value to chat_function
    # as its third argument (model_label).
    chatbot = gr.ChatInterface(
        fn=chat_function,
        additional_inputs=[model_dropdown],
    )

if __name__ == "__main__":
    demo.launch()
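
# To run locally (assuming this file is saved as app.py and dependencies are
# installed): `python app.py`, then open the local URL Gradio prints.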