File size: 3,062 Bytes
64942a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import argparse
from datasets import load_dataset
from langchain.schema import Document
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import LlamaCpp
from langchain.chains import RetrievalQA
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# Initialize the database


# Load the QA dataset and wrap each question/answer pair as a LangChain Document.
def _load_medical_documents(dataset_name, num_samples):
    """Return up to *num_samples* Documents built from the dataset's train split.

    Each row's ``instruction`` becomes the question and ``output`` the answer,
    matching the columns this script reads from the dataset.
    """
    print("🔹 Loading medical dataset...")
    ds = load_dataset(dataset_name, split="train")
    # ``select`` rejects out-of-range indices, so clamp to the dataset size.
    ds = ds.select(range(min(num_samples, len(ds))))

    print("🔹 Converting to LangChain documents...")
    return [
        Document(
            page_content=f"Question: {row['instruction']}\nAnswer: {row['output']}",
            metadata={"source": "ChatDoctor"},
        )
        for row in ds
    ]


# True when a previous run left a non-empty Chroma store in *persist_dir*.
def _has_persisted_store(persist_dir):
    """Return True if *persist_dir* is a directory that contains any files."""
    # An empty directory (e.g. from an interrupted first run) is not a usable
    # store; treating it as one would silently give an empty retriever.
    return os.path.isdir(persist_dir) and bool(os.listdir(persist_dir))


# Reuse the persisted Chroma store, or embed the dataset and persist a new one.
def _load_or_build_vectorstore(persist_dir, embedding_model, dataset_name, num_samples):
    """Return a Chroma vector store backed by *persist_dir*."""
    if _has_persisted_store(persist_dir):
        print("🔹 Loading existing ChromaDB...")
        return Chroma(persist_directory=persist_dir, embedding_function=embedding_model)

    # Only download and convert the dataset when the store actually has to be
    # built; an existing store already holds the embedded documents.
    docs = _load_medical_documents(dataset_name, num_samples)
    print("🔹 Embedding documents...")
    print("🔹 Creating new ChromaDB...")
    vectorstore = Chroma.from_documents(docs, embedding_model, persist_directory=persist_dir)
    vectorstore.persist()
    return vectorstore


# Construct the local llama.cpp model with streaming output to stdout.
def _build_llm(model_path):
    """Return a LlamaCpp LLM loaded from the GGUF file at *model_path*."""
    print("🔹 Loading local LLM model...")
    # NOTE(review): n_ctx=1024 with max_tokens=512 leaves roughly 512 tokens for
    # the prompt plus the retrieved QA pairs; long retrieved answers may not fit.
    # Consider a larger n_ctx if the model supports it — confirm.
    return LlamaCpp(
        model_path=model_path,
        n_ctx=1024,
        temperature=0.7,
        max_tokens=512,
        streaming=True,
        # Generated tokens are echoed to stdout as they are produced.
        callbacks=[StreamingStdOutCallbackHandler()],
        verbose=True,
        f16_kv=True,
        use_mlock=True,
        use_mmap=True,
        n_threads=4,
        n_batch=64,
    )


def initialize_database(
    *,
    dataset_name="lavita/ChatDoctor-HealthCareMagic-100k",
    num_samples=1000,
    persist_dir="./chroma_medical_db",
    embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_path="models/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
    top_k=3,
):
    """Build the retrieval-augmented QA chain over the medical QA dataset.

    On the first run, the first ``num_samples`` rows of ``dataset_name`` are
    embedded and persisted to ``persist_dir``. Later runs reuse that store
    and skip the dataset download.

    Args:
        dataset_name: Hugging Face dataset id with ``instruction``/``output`` columns.
        num_samples: Maximum number of rows to index when building a new store.
        persist_dir: Directory holding the persisted Chroma store.
        embedding_model_name: HuggingFace sentence-embedding model name.
        model_path: Path to the local GGUF model file for LlamaCpp.
        top_k: Number of documents the retriever returns per query.

    Returns:
        A ``RetrievalQA`` chain configured to return its source documents.

    Raises:
        FileNotFoundError: If ``model_path`` is not an existing file.
    """
    # Fail fast: otherwise a missing model is only reported after the slow
    # dataset download and embedding steps.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"LLM model file not found: {model_path}")

    embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vectorstore = _load_or_build_vectorstore(
        persist_dir, embedding_model, dataset_name, num_samples
    )
    retriever = vectorstore.as_retriever(search_kwargs={"k": top_k})

    llm = _build_llm(model_path)

    print("🔹 Building RAG chain...")
    return RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        return_source_documents=True,
    )

# Function to handle the query
def handle_query(query, qa_chain=None):
    """Answer *query* with the RAG chain and return the answer with its sources.

    Args:
        query: The user's question.
        qa_chain: An already-built chain to reuse. When ``None`` (the default),
            a new chain is built with ``initialize_database()``. That loads the
            embedding model and the LLM, so pass a chain in when answering
            several queries.

    Returns:
        A dict with ``"answer"`` (the chain's ``result`` value) and
        ``"sources"`` (the chain's ``source_documents`` value).
    """
    if qa_chain is None:
        qa_chain = initialize_database()
    print(f"🔹 Query: {query}")
    result = qa_chain(query)
    return {
        "answer": result["result"],
        "sources": result["source_documents"],
    }

# Main CLI functionality
def main():
    """Parse a query from the command line, answer it, and print answer and sources."""
    parser = argparse.ArgumentParser(description="Medical Question-Answering CLI Application")
    parser.add_argument("query", type=str, help="Query to ask the medical AI agent")
    args = parser.parse_args()

    query = args.query
    # Reject blank input up front instead of building the whole pipeline for it.
    if not query.strip():
        parser.error("query must not be empty")

    result = handle_query(query)
    print("\n🧠 Answer:")
    print(result["answer"])
    print("\nSource Documents:")
    for doc in result["sources"]:
        # Sources are LangChain ``Document`` objects, which are not
        # subscriptable; their text is stored in ``page_content``.
        print(doc.page_content)


if __name__ == "__main__":
    main()