Alexvatti commited on
Commit
1bb7c15
·
verified ·
1 Parent(s): 873cdae

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ import pandas as pd
5
+ from sentence_transformers import SentenceTransformer
6
+ import numpy as np
7
+ import faiss
8
+
9
+ app = FastAPI(title="Closed-Domain Q&A Chatbot")
10
+
11
+ app.add_middleware(
12
+ CORSMiddleware,
13
+ allow_origins=["*"],
14
+ allow_methods=["*"],
15
+ allow_headers=["*"],
16
+ )
17
+
18
+ # Data store
19
+ questions = []
20
+ answers = []
21
+ index = None
22
+ model = None
23
+
24
+ class Question(BaseModel):
25
+ query: str
26
+
27
+ @app.post("/load")
28
+ async def load_qa(file: UploadFile = File(...)):
29
+ global questions, answers, index, model
30
+
31
+ if file.filename.endswith(".csv"):
32
+ df = pd.read_csv(file.file)
33
+ elif file.filename.endswith((".xls", ".xlsx")):
34
+ df = pd.read_excel(file.file)
35
+ else:
36
+ return {"error": "Unsupported file format."}
37
+
38
+ if "Question" not in df.columns or "Answer" not in df.columns:
39
+ return {"error": "Columns 'Question' and 'Answer' required."}
40
+
41
+ questions = df["Question"].tolist()
42
+ answers = df["Answer"].tolist()
43
+
44
+ model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
45
+ question_embeddings = model.encode(questions)
46
+
47
+ index = faiss.IndexFlatL2(question_embeddings.shape[1])
48
+ index.add(np.array(question_embeddings).astype('float32'))
49
+
50
+ return {"status": "Knowledge base loaded", "total_questions": len(questions)}
51
+
52
+ @app.post("/clear")
53
+ async def clear_data():
54
+ global questions, answers, index, model
55
+ questions, answers, index, model = [], [], None, None
56
+ return {"status": "Knowledge base cleared"}
57
+
58
+ @app.post("/ask")
59
+ async def ask_question(question: Question):
60
+ if not index:
61
+ return {"answer": "Knowledge base not loaded"}
62
+
63
+ query_embedding = model.encode([question.query]).astype('float32')
64
+ D, I = index.search(query_embedding, k=1)
65
+
66
+ if D[0][0] < 50: # Distance threshold
67
+ matched_answer = answers[I[0][0]]
68
+ return {"answer": matched_answer}
69
+ else:
70
+ return {"answer": "I don’t have an answer for that."}
71
+
72
+ if __name__ == "__main__":
73
+ uvicorn.run(app, host="0.0.0.0", port=7860)