prakhardoneria committed on
Commit 5cba5a1 · verified · 1 Parent(s): 31386f7

Update app.py

Files changed (1)
  1. app.py +8 -9
app.py CHANGED
@@ -8,16 +8,16 @@ from chromadb.config import Settings
  from transformers import pipeline
 
  # Device setup
- device = -1 # Force CPU use
+ device = -1 # Use CPU
  print("Device set to: CPU")
 
  # Load CSV data
- df = pd.read_csv("iec_college_data.csv").dropna(subset=["content"]).reset_index(drop=True)
+ df = pd.read_csv("/mnt/data/iec_college_data.csv").dropna(subset=["content"]).reset_index(drop=True)
 
  # Load embedding model on CPU
  embed_model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
 
- # ChromaDB - use updated client format
+ # ChromaDB setup
  chroma_client = chromadb.PersistentClient(path="./chroma_db")
  collection_name = "iec_data"
 
@@ -45,8 +45,8 @@ if collection.count() == 0:
  print(f"Indexed {idx}/{len(df)}")
  print("Indexing complete.")
 
- # QA model: Use lighter model on CPU
- qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-small", device=device)
+ # Use lightweight extractive QA model
+ qa_pipeline = pipeline("question-answering", model="distilbert-base-uncased", device=device)
 
  # QA function
  def answer_question(user_question):
@@ -55,11 +55,10 @@ def answer_question(user_question):
  context = "\n".join(results["documents"][0])
  if len(context.split()) > 400:
  context = " ".join(context.split()[:400])
- prompt = f"You are an assistant for IEC College. Use the info below.\n\nContext:\n{context}\n\nQuestion: {user_question}\nAnswer:"
- result = qa_pipeline(prompt, max_new_tokens=200)[0]["generated_text"]
- return result.strip()
+ result = qa_pipeline(question=user_question, context=context)
+ return result["answer"]
 
- # Gradio interface
+ # Gradio UI
  iface = gr.Interface(
  fn=answer_question,
  inputs=gr.Textbox(lines=2, placeholder="Ask about IEC College..."),
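
For context on the change: the old version prompted a generative model (google/flan-t5-small) with the retrieved passages, while the new version asks an extractive question-answering pipeline to pull an answer span directly out of the retrieved context. Below is a minimal, self-contained sketch of that new call pattern, not the exact app.py code. The dummy context string and the retrieval comments are illustrative stand-ins for the ChromaDB query that app.py performs, and a SQuAD-fine-tuned checkpoint (distilbert-base-cased-distilled-squad) is used in the sketch because a plain distilbert-base-uncased checkpoint ships without a fine-tuned QA head.

# Sketch only: mirrors the new extractive QA flow in this commit.
from transformers import pipeline

# Assumption: a SQuAD-tuned checkpoint; app.py itself loads distilbert-base-uncased.
qa_pipeline = pipeline(
    "question-answering",
    model="distilbert-base-cased-distilled-squad",
    device=-1,  # CPU, as in app.py
)

# In app.py the context is assembled from ChromaDB results, roughly like:
#   emb = embed_model.encode(user_question).tolist()
#   results = collection.query(query_embeddings=[emb], n_results=3)
#   context = "\n".join(results["documents"][0])
# Here a dummy context stands in for that retrieval step.
context = "Example passage about IEC College retrieved from the vector store."

result = qa_pipeline(question="What is this passage about?", context=context)

# The pipeline returns a dict with "answer", "score", "start", and "end".
print(result["answer"], result["score"])

Nothing else in the interface needs to change for this swap: answer_question still returns a string, so the existing gr.Interface wiring keeps working.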