# Hugging Face Space -- status when this file was captured: Runtime error
# build_faiss_index.py
import os
import pickle

import faiss
import numpy as np
from datasets import load_dataset
from langchain.embeddings import HuggingFaceEmbeddings  # Or your embedding class
def format_entry(entry: dict) -> dict:
    """Turn one dataset row into the record stored alongside the index.

    Args:
        entry: A mapping that provides the keys ``question``, ``opa``,
            ``opb``, ``opc``, ``opd``, ``cop`` and ``exp``.

    Returns:
        A dict with two keys: ``"question"`` (the raw question text) and
        ``"formatted"`` (a multi-line string holding the question, the four
        options, the correct answer and the explanation).

    Raises:
        KeyError: If ``entry`` lacks any of the keys listed above.
    """
    question = entry["question"]
    options_line = (
        f"A. {entry['opa']} B. {entry['opb']} "
        f"C. {entry['opc']} D. {entry['opd']}"
    )
    # NOTE(review): `cop` is interpolated exactly as stored. If the dataset
    # encodes it as an option index rather than a letter, the text will read
    # e.g. "Correct Answer: 2" -- confirm this is the intended rendering.
    # NOTE(review): a missing explanation (None) renders as "None".
    formatted = (
        f"Q: {question}\n"
        f"{options_line}\n"
        f"Correct Answer: {entry['cop']}\n"
        f"Explanation: {entry['exp']}"
    )
    return {
        "question": question,
        "formatted": formatted,
    }
# --- Build script: load data, embed it, index it, persist index + records ---

print("Loading MedMCQA dataset (5000 rows)...")
dataset = load_dataset("medmcqa", split="train[:5000]")
formatted_data = [format_entry(entry) for entry in dataset]

# The text that gets embedded is the full formatted entry (question, options,
# answer and explanation), not the question alone.
documents = [record["formatted"] for record in formatted_data]

print("Generating embeddings...")
embeddings_model = HuggingFaceEmbeddings()  # Or your specific embeddings
# Embed all documents in one batched call instead of one model call per row;
# embed_documents is the document-side API, embed_query is for lookups.
vectors = np.array(embeddings_model.embed_documents(documents), dtype="float32")

print("Building FAISS index...")
# Row i of `vectors` corresponds to formatted_data[i]; the pickle written
# below relies on that ordering to map search hits back to records.
index = faiss.IndexFlatL2(vectors.shape[1])
index.add(vectors)

# Save FAISS index and data
output_dir = "data/medmcqa_index"
# Neither faiss.write_index nor open() creates missing parent directories.
os.makedirs(output_dir, exist_ok=True)
faiss.write_index(index, f"{output_dir}/index.faiss")
with open(f"{output_dir}/index.pkl", "wb") as f:
    pickle.dump(formatted_data, f)

print("FAISS index saved successfully!")