Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import os | |
import time | |
import json | |
import psycopg2 | |
from typing import Dict, List | |
# from dotenv import load_dotenv | |
# FastAPI ๋ฐ slowapi ๊ด๋ จ ๋ชจ๋ | |
from fastapi import FastAPI, Request | |
from slowapi import Limiter, _rate_limit_exceeded_handler | |
from slowapi.util import get_remote_address | |
from slowapi.errors import RateLimitExceeded | |
from fastapi.middleware.cors import CORSMiddleware | |
# Pydantic ๋ชจ๋ธ | |
from pydantic import BaseModel | |
# LangChain ๊ด๋ จ ๋ชจ๋ | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
# from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import PGVector | |
from langchain_core.messages import SystemMessage | |
from langchain.chains import ConversationalRetrievalChain | |
from langchain.memory import ConversationBufferMemory | |
from langchain_core.documents import Document # Document ํ์ ํํธ์ฉ์ผ๋ก ์ถ๊ฐ | |
# from pdf_importer import create_vector_store, CONNECTION_STRING, COLLECTION_NAME | |
# ํ๊ฒฝ ๋ณ์ ๋ก๋ (Hugging Face Secrets์์ ๊ฐ์ ธ์ด) | |
POSTGRES_USER = os.getenv('POSTGRES_USER') | |
POSTGRES_PASSWORD = os.getenv('POSTGRES_PASSWORD') | |
POSTGRES_HOST = os.getenv('POSTGRES_HOST') | |
POSTGRES_PORT = os.getenv('POSTGRES_PORT') | |
POSTGRES_DB = os.getenv('POSTGRES_DB') | |
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY') | |
COLLECTION_NAME = "homepage_pdfplumner_1st" | |
SENTENCE_TRANSFORMERS_HOME = os.getenv('SENTENCE_TRANSFORMERS_HOME', '/app/.cache') | |
# 2. ํ์ ํ๊ฒฝ ๋ณ์๊ฐ ๋ชจ๋ ์กด์ฌํ๋์ง ํ์ธํฉ๋๋ค. | |
if not all([POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DB, GOOGLE_API_KEY, SENTENCE_TRANSFORMERS_HOME]): | |
raise ValueError("ํ์ ํ๊ฒฝ ๋ณ์๋ค์ด ์ค์ ๋์ง ์์์ต๋๋ค. Hugging Face Secrets๋ฅผ ํ์ธํ์ธ์.") | |
# ํ๊ฒฝ ๋ณ์๋ฅผ ์กฐํฉํ์ฌ CONNECTION_STRING์ ์์ฑ | |
CONNECTION_STRING = f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}" | |
# load_dotenv() | |
app = FastAPI() | |
limiter = Limiter(key_func=get_remote_address) | |
app.state.limiter = limiter | |
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
# RAG ๊ตฌ์ฑ ์์๋ฅผ ํ๋ก๊ทธ๋จ ์์ ์ ํ ๋ฒ๋ง ์ด๊ธฐํ | |
embeddings = HuggingFaceEmbeddings( | |
model_name='nlpai-lab/KURE-v1', | |
model_kwargs={'device': 'cpu'} | |
) | |
try: | |
vector_store = PGVector( | |
collection_name=COLLECTION_NAME, | |
connection_string=CONNECTION_STRING, | |
embedding_function=embeddings | |
) | |
print("Vector store loaded from PostgreSQL.") | |
except Exception as e: | |
print(f"Error connecting to PostgreSQL: {e}") | |
import sys | |
sys.exit(1) | |
llm = ChatGoogleGenerativeAI( | |
# model="gemini-1.5-flash-8b", | |
model="gemini-2.5-flash-lite", | |
model_kwargs={ | |
"system_instruction": SystemMessage( | |
content= | |
# """๋น์ ์ ํ๊ตญ์ธ๊ตญ์ด๋ํ๊ต(์์ธ) ํ์ฌ ์ ๋ฌธ๊ฐ์ ๋๋ค. ๋ต๋ณ ์์น: 1. ํ๊ตญ์ธ๊ตญ์ด๋ํ๊ต(์์ธ) ๊ด๋ จ ์ง๋ฌธ์ ์ ํํ ๋ต๋ณํฉ๋๋ค. 2. ์ด์ ๋ํ ๋งฅ๋ฝ์ ๊ธฐ์ตํ๊ณ ์ ์ฐํ๊ฒ ์๋ตํฉ๋๋ค. 3. ์น์ ํ๊ณ ์ดํดํ๊ธฐ ์ฌ์ด ๋งํฌ๋ฅผ ์ฌ์ฉํ๋ฉฐ, ๋ฐ๋์ ์์ ํ ๋ฌธ์ฅ์ผ๋ก ๋ต๋ณํฉ๋๋ค. 4. ์ฐธ๊ณ ์ ๋ณด์ ์๋ ๋ด์ฉ์ ์ ๋ ์ถ์ธกํ๊ฑฐ๋ ์์๋ก ๋ต๋ณํ์ง ์์ต๋๋ค. ๋ต๋ณ ๊ท์น: - ํ๊ตญ์ธ๊ตญ์ด๋ํ๊ต(์์ธ)๊ณผ ๊ด๋ จ ์๋ ์ง๋ฌธ: "์ฃ์กํฉ๋๋ค. ํ๊ตญ์ธ๊ตญ์ด๋ํ๊ต(์์ธ) ๊ด๋ จ ์ง๋ฌธ์๋ง ๋ต๋ณ๋๋ฆด ์ ์์ต๋๋ค."๋ผ๊ณ ๋ต๋ณํ์ธ์. - ์ฌ์ฉ์์ ์ง๋ฌธ๊ณผ ๊ด๋ จ๋ ์ ๋ณด๊ฐ ์ฐธ๊ณ ๋ฌธ์์ ๋ช ํํ๊ฒ ์กด์ฌํ์ง ์๋ ๊ฒฝ์ฐ, ์ด๋ค ๋ด์ฉ๋ ์ถ๋ก ํ๊ฑฐ๋ ๋ง๋ถ์ด์ง ๋ง๊ณ ๋ฌด์กฐ๊ฑด "์ฃ์กํฉ๋๋ค. ํด๋น ์ ๋ณด๋ฅผ ํ์ธํ ์ ์์ต๋๋ค."๋ผ๊ณ ๋ต๋ณํ์ธ์.""" | |
""" | |
๋น์ ์ ํ๊ตญ์ธ๊ตญ์ด๋ํ๊ต(์์ธ)์ **'ํ์ฌ ์ํ AI ์ด๋๋ฐ์ด์ '**์ ๋๋ค. ๋น์ ์ ์ง์์ ์ฃผ์ด์ง [ํ์ฌ ๊ท์ ]๊ณผ [์ฃผ๋ณ ์๊ถ ์ ๋ณด] ๋ฌธ์๋ก ํ์ ๋ฉ๋๋ค. ๋น์ ์ ์๋ฌด๋ ์ด ์ง์ ๋ด์์ ํ์๋ค์ ์ง๋ฌธ์ ๋ช ํํ๊ณ ์น์ ํ ์ ๋ฌธ๊ฐ์ ์ด์กฐ๋ก ๋ต๋ณํ๋ ๊ฒ์ ๋๋ค. | |
[๋ต๋ณ ์์น] | |
1. ์ ํ์ฑ: ๋ฐ๋์ ์ฃผ์ด์ง ์ฐธ๊ณ ๋ฌธ์์ ๋ด์ฉ์๋ง ๊ทผ๊ฑฐํ์ฌ ๋ต๋ณํฉ๋๋ค. | |
2. ์น์ ํจ: ํญ์ ์น์ ํ๊ณ ์ดํดํ๊ธฐ ์ฌ์ด ์์ ํ ๋ฌธ์ฅ์ผ๋ก ๋ต๋ณํฉ๋๋ค. | |
3. ๋งฅ๋ฝ ์ดํด: ์ด์ ๋ํ ๋ด์ฉ์ ๊ธฐ์ตํ์ฌ ์์ฐ์ค๋ฌ์ด ๋ํ๋ฅผ ์ด์ด๊ฐ๋๋ค. | |
4. ์ง์ ๋ด์ฌํ: ๋น์ ์ ๋ฌธ์๋ฅผ ๋จ์ํ ์ ๋ฌํ๋ ๋ก๋ด์ด ์๋๋๋ค. ์ฃผ์ด์ง ์ฐธ๊ณ ๋ฌธ์๋ ๋น์ ์ '์ง์'์ ๋๋ค. ๋ต๋ณ ์, '์ ๊ณต๋ ์ ๋ณด', '์ฐธ๊ณ ๋ฌธ์', '์ฃผ์ด์ง ํ ์คํธ', 'ํ', '๋ฌธ๋จ' ๋ฑ ๋น์ ์ด ์ ๋ณด๋ฅผ ์ด๋ป๊ฒ ์ป์๋์ง ์์ํ๋ ๊ทธ ์ด๋ค ๋จ์ด๋ ์ ๋ ์ฌ์ฉํ์ง ๋ง์ธ์. ๊ฒ์๋ ๋ชจ๋ ์ ๋ณด๋ฅผ ์์ ํ ์์ ์ ์ง์์ธ ๊ฒ์ฒ๋ผ ์ข ํฉํ๊ณ ์์ฐ์ค๋ฝ๊ฒ ์ฌ๊ตฌ์ฑํ์ฌ, ๋ง์น ์๋๋ถํฐ ์๊ณ ์์๋ ๊ฒ์ฒ๋ผ ์ฌ์ฉ์์๊ฒ ์ง์ ์ค๋ช ํด์ผ ํฉ๋๋ค. | |
5. ํ๊ตญ์ด ์ฌ์ฉ: ๋ชจ๋ ๋ต๋ณ์ ๋ฐ๋์ ์๋ฒฝํ ํ๊ตญ์ด๋ก๋ง ์์ฑํด์ผ ํฉ๋๋ค. | |
[๋ต๋ณ ๊ท์น] | |
1. ์๊ธฐ์๊ฐ: ๋ง์ฝ ์ฌ์ฉ์๊ฐ ๋น์ ์ ์ ์ฒด์ฑ์ ๋ํด ๋ฌป๋๋ค๋ฉด(์: "๋๋ ๋๊ตฌ์ผ?", "์ด๋ฆ์ด ๋ญ์ผ?"), "์๋ ํ์ธ์! ์ ๋ ํ๊ตญ์ธ๊ตญ์ด๋ํ๊ต ํ์๋ค์ ์บ ํผ์ค ์ํ์ ๋๊ธฐ ์ํด ๋ง๋ค์ด์ง 'ํ์ฌ ์ํ AI ์ด๋๋ฐ์ด์ '์ ๋๋ค. ํ์ฌ ์ ๋ณด๋ ํ๊ต ์ํ์ ๋ํด ๊ถ๊ธํ ์ ์ด ์๋ค๋ฉด ๋ฌด์์ด๋ ๋ฌผ์ด๋ณด์ธ์." ๋ผ๊ณ ์ ํํ ์๊ฐํด์ผ ํฉ๋๋ค. ์ ๋๋ก 'Google์ ์ธ์ด ๋ชจ๋ธ'์ด๋ ๋ง์ค์ฝํธ '๋ถ(Boo)'๋ผ๊ณ ์์ ์ ์๊ฐํด์๋ ์ ๋ฉ๋๋ค. | |
1. ๋ฒ์ ์ธ ์ง๋ฌธ ํ๋จ: ๋น์ ์ ์ง์ ๋ฒ์(ํ์ฌ, ์ฃผ๋ณ ๋ง์ง)์ ๋ช ๋ฐฑํ ๊ด๋ จ ์๋ ์ง๋ฌธ(์: ๊ธ์ต, ์คํฌ์ธ )์๋ "์ฃ์กํฉ๋๋ค. ์ ๋ ํ๊ตญ์ธ๊ตญ์ด๋ํ๊ต ํ์ฌ ๋ฐ ์บ ํผ์ค ์ํ ์ ๋ณด์ ๋ํด์๋ง ๋ต๋ณํ ์ ์์ต๋๋ค." ๋ผ๊ณ ๋ต๋ณํ์ธ์. '์ ๊ณต๋ ์ ๋ณด์ ์๋ค'๋ ์์ ๋ถ์ฐ ์ค๋ช ์ ์ ๋ ๋ง๋ถ์ด์ง ๋ง์ธ์. | |
2. ์ ๋ณด ์ฐ์ ์์ ํ๋ณ: ์ฌ๋ฌ ๊ฐ์ ์ฐธ๊ณ ๋ฌธ์๊ฐ ์ฃผ์ด์ง๋ฉด, ๊ทธ์ค์์ ์ฌ์ฉ์์ ์ง๋ฌธ์ ๊ฐ์ฅ ์ง์ ์ ์ผ๋ก ๋ตํ ์ ์๋ ํต์ฌ ์ ๋ณด๋ฅผ ๋จผ์ ์๋ณํ์ธ์. ๊ด๋ จ์ฑ์ด ๋จ์ด์ง๊ฑฐ๋ ๋ถ์ฐจ์ ์ธ ์ ๋ณด๋ ๋ต๋ณ์ ํฌํจํ์ง ์๊ฑฐ๋, ๊ผญ ํ์ํ ๊ฒฝ์ฐ์๋ง ๊ฐ๋ตํ๊ฒ ๋ง๋ถ์ฌ ์ค๋ช ํ์ธ์. | |
3. ํ(Table) ๋ถ์: ์ฐธ๊ณ ๋ฌธ์์ ํ๊ฐ ํฌํจ๋ ๊ฒฝ์ฐ, ๋น์ ์ ํ ๋ถ์ ์ ๋ฌธ๊ฐ๋ก์ ํ๊ณผ ์ด์ ๊ด๊ณ๋ฅผ ์ ํํ ํด์ํ์ฌ ๋ต๋ณํด์ผ ํฉ๋๋ค. | |
4. ์กฐ๊ฑด๋ถ ๋ต๋ณ: ๋ง์ฝ ํ๋ ํ ์คํธ์ ํ๊ณผ, ํ๋ฒ ๋ฑ ์ธ๋ถ ์กฐ๊ฑด์ด ๋ช ์๋์ด ์์ง ์๋ค๋ฉด, "์ ์๋ ์๋ฃ์ ๋ฐ๋ฅด๋ฉด ์ผ๋ฐ์ ์ผ๋ก" ๋๋ "2025ํ๋ ๋ ๊ธฐ์ค์ผ๋ก๋" ๊ณผ ๊ฐ์ด ์ ๋ณด์ ์ถ์ฒ๋ ๊ธฐ์ค์ ๋ช ํํ ๋ฐํ๋ฉฐ ๋ต๋ณํ์ธ์. | |
5. ๋ค์ค ์ ๋ณด ์ฒ๋ฆฌ: ๋ง์ฝ ์ฌ์ฉ์์ ์ง๋ฌธ์ ๋ํด ์ฌ๋ฌ ๋ฌธ์์์ ์๋ก ๋ค๋ฅธ ์ ๋ณด๊ฐ ๊ฒ์๋ ๊ฒฝ์ฐ, ํ๋์ ์ ๋ณด๋ง ์ ํํ์ง ๋ง์ธ์. ๋์ , ๊ฐ๊ฐ์ ์กฐ๊ฑด๊ณผ ๋ด์ฉ์ ๋ช ํํ ๊ตฌ๋ถํ์ฌ ๋ชจ๋ ์ ๋ณด๋ฅผ ์ข ํฉ์ ์ผ๋ก ์๋ดํด์ผ ํฉ๋๋ค. | |
6. ์์ธ ๊ฐ๋ฅ์ฑ ์ธ์ง: ํ์ฌ ๊ท์ ์ ๋จ๊ณผ๋ํ, ํ๊ณผ, ํ๋ฒ๋ณ๋ก ์์ธ ๊ท์น์ด ์กด์ฌํ ์ ์๋ค๋ ์ฌ์ค์ ํญ์ ์ธ์งํ์ธ์. ๋ง์ฝ ์ผ๋ฐ์ ์ธ ๊ท์น์ ์ฐพ์๋๋ผ๋, "์ผ๋ฐ์ ์ผ๋ก๋ OOํ์ ์ด ํ์ํ์ง๋ง, ์์ ๋จ๊ณผ๋ํ์ด๋ ํ๊ณผ์ ๋ฐ๋ผ ๋ค๋ฅผ ์ ์์ผ๋ ์ ํํ ์ ๋ณด๋ ํ๊ต ๊ณต์ ๋ฌธ์๋ฅผ ํ์ธํ์๊ฑฐ๋ ํ๊ณผ ์ฌ๋ฌด์ค์ ๋ฌธ์ํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค" ์ ๊ฐ์ด ๋ต๋ณ์ '์ฃผ์์ฌํญ'๊ณผ 'ํ๊ณ'๋ฅผ ๋ช ์ํ์ธ์. | |
7. ์ ๋ณด ๋ถ์ฌ ์: ์์ ๋ชจ๋ ๋ ธ๋ ฅ์๋ ๋ถ๊ตฌํ๊ณ ์ง๋ฌธ์ ๋ํ ๋ต๋ณ์ ์ฐธ๊ณ ๋ฌธ์์์ ์ฐพ์ ์ ์๋ ๊ฒฝ์ฐ์๋ง, "์ฃ์กํฉ๋๋ค. ๋ฌธ์ํ์ ๋ด์ฉ์ ๋ํ ์ ๋ณด๋ ์ ๊ฐ ๊ฐ์ง ์๋ฃ์์ ํ์ธํ ์ ์์ต๋๋ค."๋ผ๊ณ ๋ต๋ณํ์ธ์. | |
""" | |
), | |
} | |
) | |
retriever = vector_store.as_retriever(search_kwargs={"k": 3}) | |
# retriever = MultiQueryRetriever.from_llm( | |
# retriever=vector_store.as_retriever(search_kwargs={"k": 5}), | |
# llm=llm | |
# ) | |
# ์ฌ์ฉ์ ์ธ์ ๋ณ ๋ํ ์ฒด์ธ์ ์ ์ฅํ ๋์ ๋๋ฆฌ | |
chat_sessions: Dict[str, ConversationalRetrievalChain] = {} | |
def get_or_create_chain(session_id: str) -> ConversationalRetrievalChain: | |
if session_id not in chat_sessions: | |
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, | |
input_key="question", # <-- ์ถ๊ฐ | |
output_key="answer" ) | |
new_chain = ConversationalRetrievalChain.from_llm( | |
llm=llm, | |
retriever=retriever, | |
memory=memory, | |
return_source_documents=True, # ์ฐธ๊ณ ๋ฌธ์ ๋ฐํ ํ์ฑํ | |
output_key="answer" | |
) | |
chat_sessions[session_id] = new_chain | |
print(f"์๋ก์ด ์ธ์ ID ์์ฑ: {session_id}") | |
return chat_sessions[session_id] | |
class ChatMessage(BaseModel): | |
message: str | |
session_id: str | |
user_id: str # ์ฌ์ฉ์ ์๋ณ์ ์ํด ์ถ๊ฐ | |
class ChatResponse(BaseModel): | |
response: str | |
success: bool | |
# source_documents ํ๋๋ฅผ ์ถ๊ฐํ์ฌ ํ๋ก ํธ์๋๋ก๋ ๋ณด๋ผ ์ ์๋๋ก ์ค๋น | |
source_documents: List[Dict[str, str]] = [] # ๋ฌธ์ ๋ด์ฉ๊ณผ ๋ฉํ๋ฐ์ดํฐ ์ ์ฅ | |
async def chat_with_gemini(request: Request): | |
start_time = time.time() | |
try: | |
# JSON body๋ฅผ ์ง์ ํ์ฑ | |
body = await request.json() | |
chat_message = ChatMessage(**body) | |
# qa_chain = get_or_create_chain(request.session_id) | |
# result = qa_chain.invoke({"question": request.message}) | |
qa_chain = get_or_create_chain(chat_message.session_id) | |
result = qa_chain.invoke({"question": chat_message.message}) | |
# ์ฐธ๊ณ ๋ฌธ์ ์ถ์ถ ๋ฐ ๋ก๊ทธ ์ถ๋ ฅ | |
source_documents_for_response: List[Dict[str, str]] = [] | |
if 'source_documents' in result and result['source_documents']: | |
print("\n--- ์ฐธ๊ณ ๋ฌธ์ ---") | |
for i, doc in enumerate(result['source_documents']): | |
print(f"๋ฌธ์ {i+1}:") | |
print(f" ์์ค: {doc.metadata.get('source', '์ ์ ์์')}") | |
print(f" ๋ด์ฉ (์ผ๋ถ): {doc.page_content[:200]}...") # ๋ด์ฉ์ ์ผ๋ถ๋ง ์ถ๋ ฅ | |
# ํ๋ก ํธ์๋ ์๋ต์ ์ํด ์ ์ฅ | |
source_documents_for_response.append({ | |
"source": doc.metadata.get('source', '์ ์ ์์'), | |
"content": doc.page_content # ์ ์ฒด ๋ด์ฉ์ ๋ณด๋ผ ์๋ ์์ | |
}) | |
print("---------------\n") | |
# ========================================================== | |
# โผโผโผโผโผโผโผโผโผโผโผโผโผโผโผโผโผโผ ์ด ๋ถ๋ถ๋ง ์ถ๊ฐ โผโผโผโผโผโผโผโผโผโผโผโผโผโผโผโผโผโผ | |
# ========================================================== | |
response_time_ms = int((time.time() - start_time) * 1000) | |
# DB์ ๋ก๊ทธ ์ ์ฅ | |
try: | |
db_conn_str = CONNECTION_STRING.replace("postgresql+psycopg2", "postgresql") | |
conn = psycopg2.connect(db_conn_str) | |
cur = conn.cursor() | |
cur.execute( | |
""" | |
INSERT INTO chat_logs (session_id, user_id, user_question, bot_answer, retrieved_sources, response_time_ms) | |
VALUES (%s, %s, %s, %s, %s, %s); | |
""", | |
(chat_message.session_id, chat_message.user_id, chat_message.message, result['answer'], json.dumps(source_documents_for_response), response_time_ms) | |
) | |
conn.commit() | |
cur.close() | |
conn.close() | |
except Exception as db_error: | |
print(f"DB ๋ก๊ทธ ์ ์ฅ ์คํจ: {db_error}") | |
# ========================================================== | |
# โฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒ ์ด ๋ถ๋ถ๋ง ์ถ๊ฐ โฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒโฒ | |
# ========================================================== | |
return ChatResponse( | |
response=result['answer'], | |
success=True, | |
source_documents=source_documents_for_response # ์๋ต์ ์ฐธ๊ณ ๋ฌธ์ ์ถ๊ฐ | |
) | |
except Exception as e: | |
print(f"์ค๋ฅ ๋ฐ์: {str(e)}") | |
return ChatResponse( | |
response=f"์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(e)}", | |
success=False, | |
source_documents=[] | |
) | |
async def root(): | |
return {"message": "ํ๊ตญ์ธ๊ตญ์ด๋ํ๊ต(์์ธ) ํ์ฌ ์ฑ๋ด API"} | |