|
import functools

import gradio as gr

from indexer import (
    create_vector_database,
    get_llm,
    get_prompt_template,
)
|
|
|
|
|
def format_contexts(contexts):
    """Render retrieved documents as numbered reference entries for a prompt.

    Each document contributes a ``"Reference <n>:"`` header line (numbered
    from 1), followed by its ``metadata['question']`` and
    ``metadata['answer']`` values on their own lines. Entries are separated
    by a single newline.

    Args:
        contexts: Iterable of documents, each exposing a ``metadata`` mapping
            that contains ``'question'`` and ``'answer'`` keys.

    Returns:
        The concatenated reference text, or ``""`` when *contexts* is empty.

    Raises:
        KeyError: If a document's metadata lacks ``'question'`` or ``'answer'``.
    """
    entries = []
    for number, document in enumerate(contexts, start=1):
        meta = document.metadata
        entries.append(f"Reference {number}:\n{meta['question']}\n{meta['answer']}")
    return "\n".join(entries)
|
|
|
|
|
class CustomRAG:
    """Minimal retrieval-augmented generation pipeline.

    For each query it retrieves the top-``k`` documents from ``vector_db``,
    renders them with ``format_contexts``, fills ``prompt_template`` with the
    rendered context and the query, and sends the resulting prompt to ``llm``.
    """

    def __init__(self, vector_db, llm, prompt_template, *, k=3):
        """Store the pipeline components.

        Args:
            vector_db: Vector store exposing ``as_retriever(search_kwargs=...)``.
            llm: Model exposing ``invoke(prompt)``.
            prompt_template: Template exposing
                ``format(context=..., question=...)``.
            k: Number of documents to retrieve per query. Defaults to 3, the
                value that was previously hard-coded in ``run``.

        Raises:
            ValueError: If ``k`` is less than 1.
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}")
        self.vector_db = vector_db
        self.llm = llm
        self.prompt_template = prompt_template
        self.k = k

    def run(self, query):
        """Answer *query* using retrieved context.

        Args:
            query: The user's question text.

        Returns:
            A ``(response, contexts)`` tuple. ``response`` is whatever
            ``llm.invoke`` returns; ``contexts`` holds the documents the
            retriever returned for *query*.
        """
        # The retriever is built per call from the current self.vector_db /
        # self.k, so later reassignment of either attribute takes effect.
        retriever = self.vector_db.as_retriever(search_kwargs={"k": self.k})
        contexts = retriever.invoke(query)
        formatted_context = format_contexts(contexts)
        prompt = self.prompt_template.format(context=formatted_context, question=query)
        return self.llm.invoke(prompt), contexts
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _get_rag():
    """Build the RAG pipeline once and reuse it for every subsequent request.

    ``get_llm`` and ``create_vector_database`` are given model identifiers, so
    they presumably load model weights / build embeddings. Previously that
    work was repeated on every question. If construction raises,
    ``lru_cache`` does not cache the failure, so the next call retries.

    NOTE(review): the cached vector database will not pick up changes to its
    underlying data until the process restarts.
    """
    llm = get_llm("google/flan-t5-base")
    vector_database = create_vector_database("sentence-transformers/all-MiniLM-L6-v2")
    prompt_template = get_prompt_template()
    return CustomRAG(vector_database, llm, prompt_template)


def answer_question(query):
    """Gradio callback: answer *query* with the shared RAG pipeline.

    Args:
        query: Free-text description of the user's medical concern.

    Returns:
        The LLM response produced by ``CustomRAG.run``. The retrieved
        documents are discarded.
    """
    response, _ = _get_rag().run(query)
    return response
|
|
|
|
|
# Gradio UI: one multi-line textbox in, plain text out, answered by
# answer_question. `demo` stays at module level so Gradio tooling
# (e.g. `gradio app.py` reload mode) can find it.
demo = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(
            label="Describe your medical concern",
            placeholder="e.g. I've been feeling tired and dizzy lately.",
            lines=3,
        ),
    ],
    outputs="text",
    # Title previously contained a mis-encoded character ("β") where a dash
    # was intended.
    title="Medical Assistant — RAG",
    description=(
        "Get helpful insights based on your described symptoms. "
        "This assistant uses medical reference data to provide informative responses. "
        "Note: This is not a substitute for professional medical advice."
    ),
)

# Only start the server when run as a script, so importing this module
# (tests, reload tooling) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|