Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, UploadFile, File | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
import pandas as pd | |
from sentence_transformers import SentenceTransformer | |
import numpy as np | |
import faiss | |
# CORS is left fully open: any origin, HTTP method and header is accepted.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI(title="Closed-Domain Q&A Chatbot")
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# In-memory knowledge base. The four names are module-level state that the
# request handlers rebind via ``global``.
questions, answers = [], []  # parallel lists: answers[i] pairs with questions[i]
index = model = None  # FAISS index and sentence encoder; None until data is loaded
class Question(BaseModel):
    """Payload carrying the text to look up in the knowledge base."""

    query: str
async def load_qa(file: UploadFile = File(...)):
    """Build the in-memory knowledge base from an uploaded spreadsheet.

    The upload must be a CSV or Excel file containing ``Question`` and
    ``Answer`` columns. Rows missing either value are skipped. The questions
    are embedded and stored in a FAISS L2 index. The module-level
    ``questions``, ``answers``, ``index`` and ``model`` are rebound only
    after the new index has been built, so an upload that fails part-way
    leaves the previously loaded knowledge base untouched.

    Args:
        file: The uploaded file; its name's extension selects the parser.

    Returns:
        ``{"status": ..., "total_questions": n}`` on success, otherwise
        ``{"error": message}``.
    """
    global questions, answers, index, model

    # The filename can be absent, and extensions are not always lowercase.
    filename = (file.filename or "").lower()
    try:
        if filename.endswith(".csv"):
            df = pd.read_csv(file.file)
        elif filename.endswith((".xls", ".xlsx")):
            df = pd.read_excel(file.file)
        else:
            return {"error": "Unsupported file format."}
    except ValueError as exc:
        # pandas signals empty or malformed input with ValueError subclasses.
        return {"error": f"Could not read file: {exc}"}

    if "Question" not in df.columns or "Answer" not in df.columns:
        return {"error": "Columns 'Question' and 'Answer' required."}

    # Blank cells arrive as NaN, which can neither be embedded nor returned
    # as a meaningful answer, so drop incomplete rows up front.
    qa_rows = df.dropna(subset=["Question", "Answer"])
    if qa_rows.empty:
        return {"error": "No question/answer rows found."}

    new_questions = qa_rows["Question"].astype(str).tolist()
    new_answers = qa_rows["Answer"].tolist()

    # Reuse an encoder that is already loaded instead of re-instantiating it
    # on every upload.
    encoder = model
    if encoder is None:
        encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    embeddings = np.ascontiguousarray(
        encoder.encode(new_questions), dtype='float32'
    )
    new_index = faiss.IndexFlatL2(embeddings.shape[1])
    new_index.add(embeddings)

    # Publish everything in one step so the lists and the index never disagree.
    questions, answers, index, model = (
        new_questions,
        new_answers,
        new_index,
        encoder,
    )
    return {"status": "Knowledge base loaded", "total_questions": len(questions)}
async def clear_data():
    """Reset the in-memory knowledge base to its initial, empty state."""
    global questions, answers, index, model
    questions = []
    answers = []
    index = None
    model = None
    return {"status": "Knowledge base cleared"}
async def ask_question(question: Question):
    """Return the stored answer whose question is nearest to the query.

    The query text is embedded with the loaded encoder and looked up in the
    FAISS index (single nearest neighbour). The paired answer is returned
    when the reported distance is below the threshold; otherwise a fallback
    message is returned.

    Args:
        question: Payload whose ``query`` field holds the user's text.

    Returns:
        A dict with a single ``"answer"`` key.
    """
    # Test against None explicitly instead of relying on the truthiness of
    # the index object; the encoder is checked too since both are required.
    if index is None or model is None:
        return {"answer": "Knowledge base not loaded"}

    # NOTE(review): IndexFlatL2 reports squared L2 distances. If the encoder
    # emits unit-length vectors these cannot exceed 4, in which case this
    # threshold never rejects a match -- confirm and tune.
    max_distance = 50

    query_embedding = model.encode([question.query]).astype('float32')
    distances, labels = index.search(query_embedding, k=1)
    best_distance = distances[0][0]
    best_label = int(labels[0][0])

    # FAISS uses label -1 when no neighbour exists; without the bounds check
    # that would silently pick the last stored answer.
    if 0 <= best_label < len(answers) and best_distance < max_distance:
        return {"answer": answers[best_label]}
    return {"answer": "I don’t have an answer for that."}
if __name__ == "__main__":
    # Local import: uvicorn is needed only when this file is executed
    # directly, and no import of it is visible at the top of the module.
    import uvicorn

    # Serve on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)