import os
import argparse

from datasets import load_dataset
from langchain.schema import Document
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import LlamaCpp
from langchain.chains import RetrievalQA
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler


# Build the vector store and RAG chain
def initialize_database():
    print("šŸ”¹ Loading medical dataset...")
    ds = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k", split="train")
    qa_pairs = [
        {"question": x["instruction"], "answer": x["output"]}
        for x in ds.select(range(1000))
    ]

    # Convert to LangChain Documents
    print("šŸ”¹ Converting to LangChain documents...")
    docs = [
        Document(
            page_content=f"Question: {item['question']}\nAnswer: {item['answer']}",
            metadata={"source": "ChatDoctor"},
        )
        for item in qa_pairs
    ]

    # Embedding model (documents are embedded below, when the Chroma index is built)
    print("šŸ”¹ Loading embedding model...")
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # ChromaDB setup: create and persist on first run, reuse afterwards
    persist_dir = "./chroma_medical_db"
    if not os.path.exists(persist_dir):
        print("šŸ”¹ Creating new ChromaDB...")
        vectorstore = Chroma.from_documents(docs, embedding_model, persist_directory=persist_dir)
        vectorstore.persist()
    else:
        print("šŸ”¹ Loading existing ChromaDB...")
        vectorstore = Chroma(persist_directory=persist_dir, embedding_function=embedding_model)

    # Retriever: fetch the 3 most similar Q&A documents per query
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    # Local LLM setup
    print("šŸ”¹ Loading local LLM model...")
    llm = LlamaCpp(
        model_path="models/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
        n_ctx=1024,
        temperature=0.7,
        max_tokens=512,
        streaming=True,
        callbacks=[StreamingStdOutCallbackHandler()],
        verbose=True,
        f16_kv=True,
        use_mlock=True,
        use_mmap=True,
        n_threads=4,
        n_batch=64,
    )

    # Build RAG QA chain
    print("šŸ”¹ Building RAG chain...")
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        return_source_documents=True,
    )
    return qa_chain


# Run a single query through the RAG chain
def handle_query(query):
    qa_chain = initialize_database()
    print(f"šŸ”¹ Query: {query}")
    # RetrievalQA expects its input under the "query" key
    result = qa_chain({"query": query})
    return {
        "answer": result["result"],
        "sources": result["source_documents"],
    }


# Main CLI functionality
def main():
    parser = argparse.ArgumentParser(description="Medical Question-Answering CLI Application")
    parser.add_argument("query", type=str, help="Query to ask the medical AI agent")
    args = parser.parse_args()

    result = handle_query(args.query)

    print("\n🧠 Answer:")
    print(result["answer"])
    print("\nSource Documents:")
    for doc in result["sources"]:
        # source_documents are LangChain Document objects, not dicts
        print(doc.page_content)


if __name__ == "__main__":
    main()
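# Example invocation (a minimal sketch; the script filename "medical_qa.py" and the
# sample question are assumptions, and the GGUF model must already exist at the
# model_path configured above):
#
#   python medical_qa.py "What could cause persistent headaches and dizziness?"
#
# The answer streams to stdout token by token (via StreamingStdOutCallbackHandler),
# followed by the final answer and the retrieved source documents.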