# Medical question-answering CLI: retrieval-augmented generation over the
# ChatDoctor dataset, with Chroma for vector search and a local llama.cpp
# model for answer generation.
import argparse
import os

from datasets import load_dataset
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import LlamaCpp
from langchain.schema import Document
from langchain.vectorstores import Chroma
# Initialize the database | |
def initialize_database(): | |
print("๐น Loading medical dataset...") | |
ds = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k", split="train") | |
qa_pairs = [{"question": x["instruction"], "answer": x["output"]} for x in ds.select(range(1000))] | |
# Convert to LangChain Documents | |
print("๐น Converting to LangChain documents...") | |
docs = [ | |
Document( | |
page_content=f"Question: {item['question']}\nAnswer: {item['answer']}", | |
metadata={"source": "ChatDoctor"} | |
) | |
for item in qa_pairs | |
] | |
# Embedding documents | |
print("๐น Embedding documents...") | |
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") | |
# ChromaDB setup | |
persist_dir = "./chroma_medical_db" | |
if not os.path.exists(persist_dir): | |
print("๐น Creating new ChromaDB...") | |
vectorstore = Chroma.from_documents(docs, embedding_model, persist_directory=persist_dir) | |
vectorstore.persist() | |
else: | |
print("๐น Loading existing ChromaDB...") | |
vectorstore = Chroma(persist_directory=persist_dir, embedding_function=embedding_model) | |
# Setup the retriever | |
retriever = vectorstore.as_retriever(search_kwargs={"k": 3}) | |
# Local LLM setup | |
print("๐น Loading local LLM model...") | |
llm = LlamaCpp( | |
model_path="models/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", | |
n_ctx=1024, | |
temperature=0.7, | |
max_tokens=512, | |
streaming=True, | |
callbacks=[StreamingStdOutCallbackHandler()], | |
verbose=True, | |
f16_kv=True, | |
use_mlock=True, | |
use_mmap=True, | |
n_threads=4, | |
n_batch=64 | |
) | |
# Build RAG QA chain | |
print("๐น Building RAG chain...") | |
qa_chain = RetrievalQA.from_chain_type( | |
llm=llm, | |
retriever=retriever, | |
return_source_documents=True | |
) | |
return qa_chain | |
# Function to handle the query | |
def handle_query(query): | |
qa_chain = initialize_database() | |
print(f"๐น Query: {query}") | |
result = qa_chain(query) | |
response = { | |
"answer": result['result'], | |
"sources": result['source_documents'] | |
} | |
return response | |
# Main CLI functionality | |
def main(): | |
parser = argparse.ArgumentParser(description="Medical Question-Answering CLI Application") | |
parser.add_argument("query", type=str, help="Query to ask the medical AI agent") | |
args = parser.parse_args() | |
query = args.query | |
result = handle_query(query) | |
print("\n๐ง Answer:") | |
print(result["answer"]) | |
print("\nSource Documents:") | |
for doc in result["sources"]: | |
print(doc["text"]) | |
if __name__ == "__main__": | |
main() | |