# Hugging Face Space page header captured along with this file: "Running on CPU Upgrade".
import os | |
import time | |
import json | |
import psycopg2 | |
from typing import Dict, List | |
# from dotenv import load_dotenv | |
# FastAPI and slowapi (rate limiting) modules
from fastapi import FastAPI, Request | |
from slowapi import Limiter, _rate_limit_exceeded_handler | |
from slowapi.util import get_remote_address | |
from slowapi.errors import RateLimitExceeded | |
from fastapi.middleware.cors import CORSMiddleware | |
# Pydantic models
from pydantic import BaseModel | |
# LangChain modules
from langchain_google_genai import ChatGoogleGenerativeAI | |
# from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import PGVector | |
from langchain_core.messages import SystemMessage | |
from langchain.chains import ConversationalRetrievalChain | |
from langchain.memory import ConversationBufferMemory | |
from langchain_core.documents import Document # Document νμ ννΈμ©μΌλ‘ μΆκ° | |
# from pdf_importer import create_vector_store, CONNECTION_STRING, COLLECTION_NAME | |
# Configuration comes from environment variables (set as Hugging Face Secrets).
POSTGRES_USER = os.getenv('POSTGRES_USER')
POSTGRES_PASSWORD = os.getenv('POSTGRES_PASSWORD')
POSTGRES_HOST = os.getenv('POSTGRES_HOST')
POSTGRES_PORT = os.getenv('POSTGRES_PORT')
POSTGRES_DB = os.getenv('POSTGRES_DB')
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
COLLECTION_NAME = "homepage_pdfplumner_1st"
SENTENCE_TRANSFORMERS_HOME = os.getenv('SENTENCE_TRANSFORMERS_HOME', '/app/.cache')

# Refuse to start when any required setting is unset or empty, and name the
# missing variables so the misconfiguration is obvious from the traceback.
_missing_env_vars = [
    name
    for name, value in (
        ('POSTGRES_USER', POSTGRES_USER),
        ('POSTGRES_PASSWORD', POSTGRES_PASSWORD),
        ('POSTGRES_HOST', POSTGRES_HOST),
        ('POSTGRES_PORT', POSTGRES_PORT),
        ('POSTGRES_DB', POSTGRES_DB),
        ('GOOGLE_API_KEY', GOOGLE_API_KEY),
        ('SENTENCE_TRANSFORMERS_HOME', SENTENCE_TRANSFORMERS_HOME),
    )
    if not value
]
if _missing_env_vars:
    raise ValueError(
        "필수 환경 변수들이 설정되지 않았습니다. Hugging Face Secrets를 확인하세요. "
        f"(누락된 변수: {', '.join(_missing_env_vars)})"
    )

# Build the SQLAlchemy-style URL handed to PGVector. The chat-log writer reuses
# it with the "+psycopg2" driver suffix stripped.
# User name and password are percent-encoded, so characters such as '@', ':',
# '/' or '#' in a secret cannot break the URL structure.
# NOTE(review): if the deployed POSTGRES_PASSWORD secret was percent-encoded by
# hand to work around the old behaviour, store the raw password instead.
from urllib.parse import quote as _quote_url_part

CONNECTION_STRING = (
    f"postgresql+psycopg2://{_quote_url_part(POSTGRES_USER, safe='')}"
    f":{_quote_url_part(POSTGRES_PASSWORD, safe='')}"
    f"@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}"
)
# Web application object plus rate limiting and CORS configuration.
app = FastAPI()

# Rate limits are keyed on the client's remote address. The limiter is
# registered on app.state (as slowapi requires), and slowapi's handler is
# installed for RateLimitExceeded.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; consider listing the frontend origin(s) explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# The RAG building blocks are created once, when the process starts.
embeddings = HuggingFaceEmbeddings(
    model_name='nlpai-lab/KURE-v1',
    cache_folder=SENTENCE_TRANSFORMERS_HOME,
    model_kwargs={'device': 'cpu'},
)

# The service is useless without the pgvector collection, so a failure to open
# it prints the error and terminates the process with exit status 1.
try:
    vector_store = PGVector(
        embedding_function=embeddings,
        connection_string=CONNECTION_STRING,
        collection_name=COLLECTION_NAME,
    )
except Exception as e:
    print(f"Error connecting to PostgreSQL: {e}")
    raise SystemExit(1)
else:
    print("Vector store loaded from PostgreSQL.")
# Persona and answering rules for the HUFS (Seoul) "academic life AI advisor".
# The prompt is written in Korean because every answer must be in Korean.
SYSTEM_PROMPT = """
당신은 한국외국어대학교(서울)의 **'학사 생활 AI 어드바이저'**입니다. 당신의 지식은 주어진 [학사 규정]과 [주변 상권 정보] 문서로 한정됩니다. 당신의 임무는 이 지식 내에서 학생들의 질문에 명확하고 친절한 전문가의 어조로 답변하는 것입니다.
[답변 원칙]
1. 정확성: 반드시 주어진 참고 문서의 내용에만 근거하여 답변합니다.
2. 친절함: 항상 친절하고 이해하기 쉬운 완전한 문장으로 답변합니다.
3. 맥락 이해: 이전 대화 내용을 기억하여 자연스러운 대화를 이어갑니다.
4. 지식 내재화: 당신은 문서를 단순히 전달하는 로봇이 아닙니다. 주어진 참고 문서는 당신의 '지식'입니다. 답변 시, '제공된 정보', '참고 문서', '주어진 텍스트', '표', '문단' 등 당신이 정보를 어떻게 얻었는지 암시하는 그 어떤 단어도 절대 사용하지 마세요. 검색된 모든 정보를 완전히 자신의 지식인 것처럼 종합하고 자연스럽게 재구성하여, 마치 원래부터 알고 있었던 것처럼 사용자에게 직접 설명해야 합니다.
5. 한국어 사용: 모든 답변은 반드시 완벽한 한국어로만 작성해야 합니다.
[답변 규칙]
1. 자기소개: 만약 사용자가 당신의 정체성에 대해 묻는다면(예: "너는 누구야?", "이름이 뭐야?"), "안녕하세요! 저는 한국외국어대학교 학생들의 캠퍼스 생활을 돕기 위해 만들어진 '학사 생활 AI 어드바이저'입니다. 학사 정보나 학교 생활에 대해 궁금한 점이 있다면 무엇이든 물어보세요." 라고 정확히 소개해야 합니다. 절대로 'Google의 언어 모델'이나 마스코트 '부(Boo)'라고 자신을 소개해서는 안 됩니다.
2. 범위 외 질문 판단: 당신의 지식 범위(학사, 주변 맛집)와 명백히 관련 없는 질문(예: 금융, 스포츠)에는 "죄송합니다. 저는 한국외국어대학교 학사 및 캠퍼스 생활 정보에 대해서만 답변할 수 있습니다." 라고 답변하세요. '제공된 정보에 없다'는 식의 부연 설명은 절대 덧붙이지 마세요.
3. 정보 우선순위 판별: 여러 개의 참고 문서가 주어지면, 그중에서 사용자의 질문에 가장 직접적으로 답할 수 있는 핵심 정보를 먼저 식별하세요. 관련성이 떨어지거나 부차적인 정보는 답변에 포함하지 않거나, 꼭 필요한 경우에만 간략하게 덧붙여 설명하세요.
4. 표(Table) 분석: 참고 문서에 표가 포함된 경우, 당신은 표 분석 전문가로서 행과 열의 관계를 정확히 해석하여 답변해야 합니다.
5. 조건부 답변: 만약 표나 텍스트에 학과, 학번 등 세부 조건이 명시되어 있지 않다면, "제시된 자료에 따르면 일반적으로" 또는 "2025학년도 기준으로는" 과 같이 정보의 출처나 기준을 명확히 밝히며 답변하세요.
6. 다중 정보 처리: 만약 사용자의 질문에 대해 여러 문서에서 서로 다른 정보가 검색된 경우, 하나의 정보만 선택하지 마세요. 대신, 각각의 조건과 내용을 명확히 구분하여 모든 정보를 종합적으로 안내해야 합니다.
7. 예외 가능성 인지: 학사 규정은 단과대학, 학과, 학번별로 예외 규칙이 존재할 수 있다는 사실을 항상 인지하세요. 만약 일반적인 규칙을 찾았더라도, "일반적으로는 OO학점이 필요하지만, 소속 단과대학이나 학과에 따라 다를 수 있으니 정확한 정보는 학교 공식 문서를 확인하시거나 학과 사무실에 문의하는 것을 권장합니다" 와 같이 답변에 '주의사항'과 '한계'를 명시하세요.
8. 정보 부재 시: 위의 모든 노력에도 불구하고 질문에 대한 답변을 참고 문서에서 찾을 수 없는 경우에만, "죄송합니다. 문의하신 내용에 대한 정보는 제가 가진 자료에서 확인할 수 없습니다."라고 답변하세요.
"""

# Gemini chat model shared by every session's retrieval chain.
# NOTE(review): confirm that ChatGoogleGenerativeAI actually forwards
# model_kwargs["system_instruction"] to Gemini. ConversationalRetrievalChain
# builds its own QA prompt, so if this key is ignored the rules above never
# reach the model. Supplying a prompt through the chain's
# combine_docs_chain_kwargs would be the chain-level alternative.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash-lite",
    model_kwargs={
        "system_instruction": SystemMessage(content=SYSTEM_PROMPT),
    },
)
# Retriever over the pgvector store; each query returns the top k=3 chunks.
retriever = vector_store.as_retriever(search_kwargs=dict(k=3))
# Conversation chains keyed by the client-supplied session_id. Dict insertion
# order doubles as recency order: the least recently used session comes first.
chat_sessions: Dict[str, ConversationalRetrievalChain] = {}

# Upper bound on cached conversations. session_id comes straight from the
# request body, so without a cap every new id would keep a chain and its
# ever-growing chat history in memory for the lifetime of the process.
MAX_CHAT_SESSIONS = 1000


def get_or_create_chain(session_id: str) -> ConversationalRetrievalChain:
    """Return the retrieval chain for ``session_id``, creating it on first use.

    Each chain gets its own ``ConversationBufferMemory``, so every session keeps
    a separate chat history. The requested session becomes the most recently
    used entry of ``chat_sessions``. When more than ``MAX_CHAT_SESSIONS``
    sessions are cached, the least recently used ones are dropped, together
    with their history.
    """
    chain = chat_sessions.pop(session_id, None)
    if chain is None:
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            input_key="question",
            # return_source_documents=True makes the chain emit more than one
            # output, so the memory must be told which one to store.
            output_key="answer",
        )
        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=memory,
            return_source_documents=True,  # include retrieved documents in the result
            output_key="answer",
        )
        print(f"새로운 세션 ID 생성: {session_id}")
    # (Re-)insert at the end so this session counts as the most recently used.
    chat_sessions[session_id] = chain
    while len(chat_sessions) > MAX_CHAT_SESSIONS:
        del chat_sessions[next(iter(chat_sessions))]
    return chain
class ChatMessage(BaseModel):
    """JSON request body accepted by the chat handler."""

    # The user's question text.
    message: str
    # Conversation id; selects which retrieval chain (and chat history) is used.
    session_id: str
    # User identifier, recorded with the question in the chat log.
    user_id: str
class ChatResponse(BaseModel):
    """Payload returned by the chat handler."""

    # The chatbot's answer, or an error description when success is False.
    response: str
    # Whether the request was answered successfully.
    success: bool
    # Retrieved reference chunks as {"source": ..., "content": ...} dicts, so the
    # frontend can show them. Pydantic copies mutable defaults per instance.
    source_documents: List[Dict[str, str]] = []
async def chat_with_gemini(request: Request):
    """Answer one chat message with the caller's per-session retrieval chain.

    The JSON request body is validated as ``ChatMessage``. The question is run
    through the session's chain, and the retrieved documents are printed. The
    exchange is then written to the ``chat_logs`` table on a best-effort basis:
    a logging failure is printed and does not fail the request.

    Always returns a ``ChatResponse``. On any error, ``success`` is False and
    ``response`` contains the error text.

    NOTE(review): no ``@app.post(...)`` / ``@limiter.limit(...)`` decorator is
    visible on this handler in this copy; presumably both were intended (the
    ``request: Request`` parameter is what slowapi's decorator needs). Confirm
    the route is actually registered.
    """
    # perf_counter is monotonic, so the latency is immune to wall-clock changes.
    start_time = time.perf_counter()
    try:
        # Parsing by hand means a malformed body is caught below and reported
        # as a ChatResponse with success=False.
        body = await request.json()
        chat_message = ChatMessage(**body)
        qa_chain = get_or_create_chain(chat_message.session_id)
        # NOTE(review): invoke() is synchronous and runs on the event loop, so
        # other requests wait while this answer is generated.
        result = qa_chain.invoke({"question": chat_message.message})
        source_documents = _collect_source_documents(result)
        answer = result['answer']
        response_time_ms = int((time.perf_counter() - start_time) * 1000)
        _save_chat_log(chat_message, answer, source_documents, response_time_ms)
        return ChatResponse(
            response=answer,
            success=True,
            source_documents=source_documents,
        )
    except Exception as e:
        print(f"오류 발생: {str(e)}")
        # NOTE(review): str(e) is returned to the client and may expose internals.
        return ChatResponse(
            response=f"오류가 발생했습니다: {str(e)}",
            success=False,
            source_documents=[],
        )


def _collect_source_documents(result) -> List[Dict[str, str]]:
    """Print the chain's retrieved documents and convert them for the response.

    Returns one ``{"source": ..., "content": ...}`` dict per entry of
    ``result["source_documents"]``, or an empty list when there are none.
    """
    documents = result.get('source_documents') or []
    collected: List[Dict[str, str]] = []
    if not documents:
        return collected
    print("\n--- 참고 문서 ---")
    for i, doc in enumerate(documents):
        source = doc.metadata.get('source', '알 수 없음')
        print(f"문서 {i+1}:")
        print(f" 소스: {source}")
        # Only a preview is printed; the response carries the full text.
        print(f" 내용 (일부): {doc.page_content[:200]}...")
        collected.append({"source": source, "content": doc.page_content})
    print("---------------\n")
    return collected


def _save_chat_log(chat_message, answer, source_documents, response_time_ms) -> None:
    """Insert one row into ``chat_logs``; any failure is printed, never raised."""
    # psycopg2 takes a plain libpq URL, without SQLAlchemy's driver suffix.
    db_conn_str = CONNECTION_STRING.replace("postgresql+psycopg2", "postgresql")
    conn = None
    try:
        conn = psycopg2.connect(db_conn_str)
        # `with conn` commits on success and rolls back on error (it does not
        # close); the cursor's context manager closes the cursor.
        with conn, conn.cursor() as cur:
            cur.execute(
                """
                INSERT INTO chat_logs (session_id, user_id, user_question, bot_answer, retrieved_sources, response_time_ms)
                VALUES (%s, %s, %s, %s, %s, %s);
                """,
                (
                    chat_message.session_id,
                    chat_message.user_id,
                    chat_message.message,
                    answer,
                    json.dumps(source_documents),
                    response_time_ms,
                ),
            )
    except Exception as db_error:
        print(f"DB 로그 저장 실패: {db_error}")
    finally:
        # Always release the connection, including when execute/commit failed.
        if conn is not None:
            conn.close()
async def root():
    """Return a short banner message identifying the API."""
    # NOTE(review): no route decorator is visible here; presumably this is meant
    # to be registered as `@app.get("/")`. Confirm.
    return {"message": "한국외국어대학교(서울) 학사 챗봇 API"}