import os
import argparse
from datasets import load_dataset
from langchain.schema import Document
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import LlamaCpp
from langchain.chains import RetrievalQA
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
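
# NOTE: these import paths follow the pre-0.1 LangChain layout (everything under
# the top-level `langchain` package). The script assumes langchain, chromadb,
# sentence-transformers, datasets and llama-cpp-python are installed (e.g. via pip).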
# Initialize the database
def initialize_database():
print("๐น Loading medical dataset...")
ds = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k", split="train")
qa_pairs = [{"question": x["instruction"], "answer": x["output"]} for x in ds.select(range(1000))]

    # Convert to LangChain Documents
    print("🔹 Converting to LangChain documents...")
    docs = [
        Document(
            page_content=f"Question: {item['question']}\nAnswer: {item['answer']}",
            metadata={"source": "ChatDoctor"}
        )
        for item in qa_pairs
    ]

    # Embedding documents
    print("🔹 Embedding documents...")
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # ChromaDB setup
    persist_dir = "./chroma_medical_db"
    if not os.path.exists(persist_dir):
        print("🔹 Creating new ChromaDB...")
        vectorstore = Chroma.from_documents(docs, embedding_model, persist_directory=persist_dir)
        vectorstore.persist()
    else:
        print("🔹 Loading existing ChromaDB...")
        vectorstore = Chroma(persist_directory=persist_dir, embedding_function=embedding_model)
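    # Note: the dataset is downloaded and converted on every run, but documents
    # are only embedded and written to disk when the persisted store does not
    # exist yet; subsequent runs just reopen the saved index.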

    # Set up the retriever
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    # Local LLM setup
    print("🔹 Loading local LLM model...")
    llm = LlamaCpp(
        model_path="models/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
        n_ctx=1024,
        temperature=0.7,
        max_tokens=512,
        streaming=True,
        callbacks=[StreamingStdOutCallbackHandler()],
        verbose=True,
        f16_kv=True,
        use_mlock=True,
        use_mmap=True,
        n_threads=4,
        n_batch=64
    )
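    # The GGUF weights are not bundled with the script: the quantized
    # TinyLlama-1.1B-Chat checkpoint (published on Hugging Face) must be
    # downloaded separately and placed at the model_path above.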

    # Build RAG QA chain
    print("🔹 Building RAG chain...")
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        return_source_documents=True
    )
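    # With no chain_type given, RetrievalQA defaults to "stuff": the k=3
    # retrieved Q&A documents are concatenated directly into the LLM prompt.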
    return qa_chain
# Function to handle the query
def handle_query(query):
    qa_chain = initialize_database()
    print(f"🔹 Query: {query}")
    result = qa_chain(query)
    response = {
        "answer": result['result'],
        "sources": result['source_documents']
    }
    return response
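
# Note: handle_query rebuilds the entire chain (dataset, embeddings, LLM) on
# every call. That is acceptable for a one-shot CLI; a long-running service
# would want to initialize once and reuse the chain across queries.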
# Main CLI functionality
def main():
    parser = argparse.ArgumentParser(description="Medical Question-Answering CLI Application")
    parser.add_argument("query", type=str, help="Query to ask the medical AI agent")
    args = parser.parse_args()
    query = args.query
    result = handle_query(query)
    print("\n🧠 Answer:")
    print(result["answer"])
    print("\nSource Documents:")
    for doc in result["sources"]:
        # each source is a LangChain Document; its text lives on .page_content
        print(doc.page_content)
if __name__ == "__main__":
    main()
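
# Example usage (assuming the script is saved as app.py and the model file is
# in place):
#   python app.py "What could be causing my persistent headaches?"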