# build_faiss_index.py
import os
import faiss
import pickle
import numpy as np
from datasets import load_dataset
# HuggingFaceEmbeddings lives in langchain_community.embeddings in newer LangChain
# releases; swap in your own embedding class if preferred.
from langchain.embeddings import HuggingFaceEmbeddings
def format_entry(entry):
    """Format one MedMCQA row into a retrievable text block plus the raw question."""
    return {
        "question": entry["question"],
        "formatted": (
            f"Q: {entry['question']}\n"
            f"A. {entry['opa']} B. {entry['opb']} C. {entry['opc']} D. {entry['opd']}\n"
            f"Correct Answer: {entry['cop']}\n"  # 'cop' is the dataset's numeric correct-option index
            f"Explanation: {entry['exp']}"
        )
    }
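# An illustrative example of one formatted entry (placeholder values, not an actual
# dataset row):
#   Q: Which vitamin deficiency causes scurvy?
#   A. Vitamin A B. Vitamin B12 C. Vitamin C D. Vitamin D
#   Correct Answer: 2
#   Explanation: Scurvy results from a lack of vitamin C (ascorbic acid).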
print("Loading MedMCQA dataset (5000 rows)...")
dataset = load_dataset("medmcqa", split="train[:5000]")
formatted_data = [format_entry(entry) for entry in dataset]
# Embed the full formatted entries (question, options, answer, and explanation)
texts = [entry["formatted"] for entry in formatted_data]
print("Generating embeddings...")
embeddings_model = HuggingFaceEmbeddings()  # or your specific embeddings
vectors = np.array([embeddings_model.embed_query(t) for t in texts], dtype="float32")
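# A possible speed-up (a sketch, assuming the standard LangChain embeddings interface):
# embed_documents() embeds all texts in one batched call instead of looping per entry:
#   vectors = np.array(embeddings_model.embed_documents(texts), dtype="float32")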
print("Building FAISS index...")
index = faiss.IndexFlatL2(vectors.shape[1])
index.add(vectors)
# Save the FAISS index and the aligned entry metadata
os.makedirs("data/medmcqa_index", exist_ok=True)
faiss.write_index(index, "data/medmcqa_index/index.faiss")
with open("data/medmcqa_index/index.pkl", "wb") as f:
    pickle.dump(formatted_data, f)  # positions in this list match the FAISS vector ids
print("FAISS index saved successfully!")