from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import pandas as pd
from sentence_transformers import SentenceTransformer
import numpy as np
import faiss

app = FastAPI(title="Closed-Domain Q&A Chatbot")

# CORS is wide open: any origin, method, and header are accepted.
# NOTE(review): fine for a local demo; tighten allow_origins before
# exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Data store: in-memory knowledge base, populated by /load, wiped by /clear.
questions = []  # question strings, row-aligned with `answers`
answers = []    # answer strings, looked up by FAISS row index in /ask
index = None    # faiss.IndexFlatL2 over question embeddings (None until /load)
model = None    # SentenceTransformer used for embedding (None until /load)

class Question(BaseModel):
    """Request body for /ask: the user's natural-language query."""
    query: str

@app.post("/load")
async def load_qa(file: UploadFile = File(...)):
    """Build the knowledge base from an uploaded CSV/Excel file.

    The file must contain 'Question' and 'Answer' columns. Question
    embeddings are indexed with FAISS for nearest-neighbour lookup in /ask.

    Returns a status dict with the number of loaded questions, or an
    ``{"error": ...}`` dict on an unusable upload.
    """
    global questions, answers, index, model

    # UploadFile.filename may be None; guard before calling .endswith.
    filename = file.filename or ""
    if filename.endswith(".csv"):
        df = pd.read_csv(file.file)
    elif filename.endswith((".xls", ".xlsx")):
        df = pd.read_excel(file.file)
    else:
        return {"error": "Unsupported file format."}

    if "Question" not in df.columns or "Answer" not in df.columns:
        return {"error": "Columns 'Question' and 'Answer' required."}

    # Drop rows with missing cells: NaN values would otherwise reach
    # model.encode() (which expects strings) and misalign the answer list.
    df = df.dropna(subset=["Question", "Answer"])
    if df.empty:
        return {"error": "No usable rows found."}

    questions = df["Question"].astype(str).tolist()
    answers = df["Answer"].astype(str).tolist()

    # Load the embedding model once and reuse it across /load calls;
    # re-instantiating it per upload is slow and may re-download weights.
    # (/clear resets it to None, so it is re-created after a clear.)
    if model is None:
        model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    question_embeddings = model.encode(questions)

    # Exact (squared) L2 nearest-neighbour index over the question vectors.
    index = faiss.IndexFlatL2(question_embeddings.shape[1])
    index.add(np.array(question_embeddings).astype('float32'))

    return {"status": "Knowledge base loaded", "total_questions": len(questions)}

@app.post("/clear")
async def clear_data():
    """Reset the in-memory knowledge base to its initial empty state."""
    global questions, answers, index, model
    questions = []
    answers = []
    index = None
    model = None
    return {"status": "Knowledge base cleared"}

@app.post("/ask")
async def ask_question(question: Question):
    """Answer a query via nearest-neighbour search over the loaded Q&A pairs."""
    # Explicit None comparison: a FAISS index object is always truthy, so
    # `is None` is the reliable "not loaded" check. `model` is cleared
    # together with `index` in /clear, so guard both before using them.
    if index is None or model is None:
        return {"answer": "Knowledge base not loaded"}

    query_embedding = model.encode([question.query]).astype('float32')
    # D: squared L2 distances, I: row indices of the nearest stored questions.
    D, I = index.search(query_embedding, k=1)

    # Reject matches that are too far from any stored question.
    # NOTE(review): 50 is very permissive for MiniLM embeddings — tune
    # this threshold against real data.
    if D[0][0] < 50:
        # I is int64; cast so plain-Python list indexing is explicit.
        matched_answer = answers[int(I[0][0])]
        return {"answer": matched_answer}
    else:
        return {"answer": "I don’t have an answer for that."}
        
if __name__ == "__main__":
    # `uvicorn` was used here without ever being imported, so running the
    # file directly raised NameError. Import locally: it is only needed
    # when this module is executed as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)