import os
import gradio as gr
from kiwipiepy import Kiwi
from typing import Generator, List, Tuple
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_community.document_transformers import LongContextReorder
from libs.config import STREAMING
from libs.embeddings import get_embeddings
from libs.retrievers import load_retrievers
from libs.llm import get_llm
from libs.prompt import get_prompt
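
# Silence the HuggingFace tokenizers fork/parallelism warning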
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def kiwi_tokenize(text):
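    """Split Korean text into morpheme surface forms with the Kiwi analyzer."""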
kiwi = Kiwi()
return [token.form for token in kiwi.tokenize(text)]
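
# Shared embedding model and retriever, constructed once at import time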
embeddings = get_embeddings()
retriever = load_retrievers(embeddings)
# Available models (key: model identifier, value: label shown to the user)
AVAILABLE_MODELS = {
"gpt_3_5_turbo": "GPT-3.5 Turbo",
"gpt_4o": "GPT-4o",
"claude_3_5_sonnet": "Claude 3.5 Sonnet",
"gemini_1_5_flash": "Gemini 1.5 Flash",
"llama3_70b": "Llama3 70b",
}
def create_rag_chain(chat_history: List[Tuple[str, str]], model: str):
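    """Build the RAG chain: retrieve context, reorder it, prompt the chosen LLM."""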
llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
prompt = get_prompt(chat_history)
return (
{
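            # LongContextReorder puts the most relevant documents at the start
            # and end of the context, mitigating "lost in the middle" effects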
"context": retriever
| RunnableLambda(LongContextReorder().transform_documents),
"question": RunnablePassthrough(),
}
| prompt
| llm
| StrOutputParser()
)
def respond_stream(
message: str, history: List[Tuple[str, str]], model: str
) -> Generator[str, None, None]:
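    """Stream the RAG chain's answer to the message, yielding text chunks."""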
rag_chain = create_rag_chain(history, model)
for chunk in rag_chain.stream(message):
yield chunk
def respond(message: str, history: List[Tuple[str, str]], model: str) -> str:
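    """Return the RAG chain's full answer in one non-streaming call."""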
rag_chain = create_rag_chain(history, model)
return rag_chain.invoke(message)
def get_model_key(label: str) -> str:
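    """Map a user-facing label back to its AVAILABLE_MODELS key."""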
return next(key for key, value in AVAILABLE_MODELS.items() if value == label)
def chat_function(
message: str, history: List[Tuple[str, str]], model_label: str
) -> Generator[str, None, None]:
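    """Gradio chat handler: resolve the model key, then stream or emit the full answer."""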
model_key = get_model_key(model_label)
if STREAMING:
response = ""
for chunk in respond_stream(message, history, model_key):
response += chunk
yield response
else:
response = respond(message, history, model_key)
yield response
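
# Gradio UI: a model picker plus a chat interface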
with gr.Blocks() as demo:
gr.Markdown("# λλ²μ νλ‘ μλ΄ λμ°λ―Έ")
gr.Markdown(
"μλ
νμΈμ! λλ²μ νλ‘μ κ΄ν μ§λ¬Έμ λ΅λ³ν΄λ리λ AI μλ΄ λμ°λ―Έμ
λλ€. νλ‘ κ²μ, ν΄μ, μ μ© λ±μ λν΄ κΆκΈνμ μ μ΄ μμΌλ©΄ μΈμ λ λ¬Όμ΄λ³΄μΈμ."
)
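    # Model picker; defaults to the second label (GPT-4o)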
model_dropdown = gr.Dropdown(
choices=list(AVAILABLE_MODELS.values()),
label="λͺ¨λΈ μ ν",
value=list(AVAILABLE_MODELS.values())[1],
)
chatbot = gr.ChatInterface(
fn=chat_function,
additional_inputs=[model_dropdown],
)
if __name__ == "__main__":
demo.launch()