SankethHonavar committed on
Commit
2e37c9b
·
1 Parent(s): 62d3203

Add prebuilt FAISS index and loader for Hugging Face Space

Browse files
Files changed (2) hide show
  1. build_faiss_index.py +39 -0
  2. dataset_loader.py +9 -16
build_faiss_index.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # build_faiss_index.py
2
import os
import pickle

import faiss
import numpy as np
from datasets import load_dataset
from langchain.embeddings import HuggingFaceEmbeddings  # Or your embedding class
7
+
8
def format_entry(entry):
    """Turn one MCQ record into the text block that is embedded and stored.

    Args:
        entry: Mapping with the keys ``question``, ``opa``, ``opb``, ``opc``,
            ``opd``, ``cop`` and ``exp``.

    Returns:
        A dict with two keys: ``question`` (the raw question) and
        ``formatted`` (question, the four options, the correct answer and
        the explanation joined into one multi-line string).

    Raises:
        KeyError: If any of the required keys is missing from ``entry``.
    """
    explanation = entry["exp"]
    # A missing explanation would otherwise be rendered as the literal text
    # "None", which is noise in the embedded passage and in retrieved context.
    if explanation is None or not str(explanation).strip():
        explanation = "Not available"

    # NOTE(review): ``cop`` is emitted exactly as stored. If the dataset
    # encodes it as an option index rather than a letter, consider mapping it
    # to A-D so it lines up with the option labels above - confirm the
    # encoding against the dataset card before changing this.
    return {
        "question": entry["question"],
        "formatted": (
            f"Q: {entry['question']}\n"
            f"A. {entry['opa']} B. {entry['opb']} C. {entry['opc']} D. {entry['opd']}\n"
            f"Correct Answer: {entry['cop']}\n"
            f"Explanation: {explanation}"
        ),
    }
18
+
19
# Output directory for the prebuilt artifacts. dataset_loader.load_faiss_index()
# reads "index.faiss" and "index.pkl" from this same location.
INDEX_DIR = "data/medmcqa_index"

print("Loading MedMCQA dataset (5000 rows)...")
dataset = load_dataset("medmcqa", split="train[:5000]")
formatted_data = [format_entry(entry) for entry in dataset]

# Texts to embed: the full formatted block (question, options, answer and
# explanation), not just the bare question.
questions = [entry["formatted"] for entry in formatted_data]

print("Generating embeddings...")
embeddings_model = HuggingFaceEmbeddings()  # Or your specific embeddings
# embed_documents() is the corpus-side entry point of the embeddings
# interface and encodes the whole list in one call, instead of one model
# invocation per text via embed_query().
vectors = np.array(embeddings_model.embed_documents(questions), dtype="float32")

print("Building FAISS index...")
index = faiss.IndexFlatL2(vectors.shape[1])
index.add(vectors)

# Save FAISS index and data. Row i of the index corresponds to
# formatted_data[i], so both files must always be written together.
# Create the target directory first: neither faiss.write_index() nor open()
# creates missing parent directories.
os.makedirs(INDEX_DIR, exist_ok=True)
faiss.write_index(index, os.path.join(INDEX_DIR, "index.faiss"))
with open(os.path.join(INDEX_DIR, "index.pkl"), "wb") as f:
    pickle.dump(formatted_data, f)

print("FAISS index saved successfully!")
dataset_loader.py CHANGED
@@ -1,18 +1,11 @@
1
  # dataset_loader.py
2
- from datasets import load_dataset
 
3
 
4
- def load_medmcqa_subset(limit=5000):
5
- dataset = load_dataset("medmcqa", split=f"train[:{limit}]")
6
-
7
- def format_entry(entry):
8
- return {
9
- "question": entry["question"],
10
- "formatted": (
11
- f"Q: {entry['question']}\n"
12
- f"A. {entry['opa']} B. {entry['opb']} C. {entry['opc']} D. {entry['opd']}\n"
13
- f"Correct Answer: {entry['cop']}\n"
14
- f"Explanation: {entry['exp']}"
15
- )
16
- }
17
-
18
- return [format_entry(entry) for entry in dataset]
 
1
  # dataset_loader.py
2
+ import pickle
3
+ import faiss
4
 
5
def load_faiss_index(index_dir="data/medmcqa_index"):
    """Load the prebuilt FAISS index and its companion records from disk.

    Args:
        index_dir: Directory containing ``index.faiss`` and ``index.pkl``.
            The default is the path this loader has always read from, so
            existing zero-argument callers are unaffected.

    Returns:
        A ``(index, formatted_data)`` tuple: the FAISS index read from
        ``index.faiss`` and the object unpickled from ``index.pkl``.

    Raises:
        ValueError: If the number of vectors in the index differs from the
            number of unpickled records, i.e. the two files do not belong
            to the same build.
    """
    print("Loading FAISS index...")
    index = faiss.read_index(f"{index_dir}/index.faiss")

    # SECURITY: pickle.load() can execute arbitrary code. Only load an
    # index.pkl that was produced by a trusted build of this project.
    with open(f"{index_dir}/index.pkl", "rb") as f:
        formatted_data = pickle.load(f)

    # Search results are row positions in the index; they are only meaningful
    # if every row has a matching record. Fail here rather than returning the
    # wrong record (or raising IndexError) at query time.
    if index.ntotal != len(formatted_data):
        raise ValueError(
            f"FAISS index holds {index.ntotal} vectors but index.pkl holds "
            f"{len(formatted_data)} records; rebuild the index."
        )

    print("FAISS index loaded successfully!")
    return index, formatted_data