# -*- coding: utf-8 -*-
"""notebook9de6b64a65 (4).ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1c2KnGGRls-UP7uEOg2uV-FX5EjJkhtte
"""
# Commented out IPython magic to ensure Python compatibility.
# %%capture output
# %pip install unsloth
# %pip install -qU "langchain-chroma>=0.1.2" langchain-huggingface langchain-core
# %pip install -U gradio pillow datasets
# %pip install assemblyai PyMuPDF
import os
from threading import Thread

from huggingface_hub import snapshot_download
import gradio as gr
import assemblyai as aai
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
import chromadb
from unsloth import FastVisionModel  # FastLanguageModel is the text-only equivalent
import torch
from langchain_core.vectorstores import InMemoryVectorStore
from transformers import TextIteratorStreamer
from PIL import Image
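# Pull the prebuilt MedQuAD Chroma index from the Hub so the app starts with a
# ready-made vector store instead of embedding the whole dataset on boot.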
snapshot_download(repo_id="pranavupadhyaya52/lavita-MedQuAD-embeddings", repo_type="dataset", local_dir="./chroma_langchain_db")
# Read the AssemblyAI key from the environment rather than hard-coding it
# (ASSEMBLYAI_API_KEY is an assumed secret name; set it in the host's secrets).
aai.settings.api_key = os.environ.get("ASSEMBLYAI_API_KEY", "")
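# Use AssemblyAI's highest-accuracy speech model for transcription.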
config = aai.TranscriptionConfig(speech_model=aai.SpeechModel.best)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2", model_kwargs={"device": "cuda:0"})
vector_store = Chroma(
collection_name="mediwiki_rag",
embedding_function=embeddings,
persist_directory="./chroma_langchain_db", # Where to save data locally, remove if not necessary
)
# Open the same persisted database directly with the Chroma client; the path and
# collection name must match the vector store above.
persistent_client = chromadb.PersistentClient(path="./chroma_langchain_db")
collection = persistent_client.get_or_create_collection("mediwiki_rag")
retriever = vector_store.as_retriever()
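# One-time indexing step that built the persisted store downloaded above, kept
# for reference (the first 41,000 MedQuAD answers, with their sources):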
"""
from datasets import load_dataset
from langchain_core.documents import Document
data = load_dataset("lavita/MedQuAD")
vector_store.add_documents(documents=[
    Document(page_content=str(i["answer"]), metadata={"source": i["document_source"]})
    for k, i in zip(range(41000), data["train"])
])
"""
model, tokenizer = FastVisionModel.from_pretrained(
"pranavupadhyaya52/llama-3.2-11b-vision-instruct-mediwiki3",
load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
use_gradient_checkpointing = "unsloth",
)
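# Unsloth models must be switched into inference mode before generation; this
# enables the fast inference path (per Unsloth's documented usage).
FastVisionModel.for_inference(model)

# Scratch in-memory store used by the supervisor's router: agent names are
# embedded and matched against the request type (file extension or "text").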
similarity_store = InMemoryVectorStore(embeddings)
class unsloth_agent:
    def __init__(self, model, tokenizer, sys_prompt, name):
        self.name = name
        self.model, self.tokenizer = model, tokenizer
        self.sys_prompt = sys_prompt

    def no_image_prompt(self, prompt):
        # Text-only template: system prompt first, then the user query.
        message = self.tokenizer.apply_chat_template([
            {"role": "system", "content": [
                {"type": "text", "text": self.sys_prompt}
            ]},
            {"role": "user", "content": [
                {"type": "text", "text": prompt}
            ]}
        ], add_generation_prompt=True)
        return message

    def yes_image_prompt(self, prompt):
        # Vision template: the image placeholder is paired with the PIL image
        # passed alongside this text when the processor is called.
        message = self.tokenizer.apply_chat_template([
            {"role": "system", "content": [
                {"type": "text", "text": self.sys_prompt}
            ]},
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": prompt}
            ]}
        ], add_generation_prompt=True)
        return message
class unsloth_agent_supervisor:
    def __init__(self, agents: list, user_input: str, file=None):
        if user_input is None and file is None:
            user_input = "Prompt the user for input."
        elif isinstance(file, list):
            user_input = "Tell user that only one multimedia object is allowed."

        # Route on the file extension (last three characters; "peg" covers
        # ".jpeg"), or on "text" when there is no file.
        agent_names = [str(x.name) for x in agents]
        if file is not None:
            _agent_name = self.similarity_finder(agent_names, file[-3:])
        else:
            _agent_name = self.similarity_finder(agent_names, "text")
        # Fall back to the first agent if no name matched.
        current_agent = next((a for a in agents if a.name == _agent_name), agents[0])

        if file is None and user_input is not None:
            image = None
            message = current_agent.no_image_prompt(user_input)
        elif file[-3:] in ["mp3", "wav"]:
            # Transcribe the audio with AssemblyAI, then treat it as text.
            image = None
            transcript = aai.Transcriber(config=config).transcribe(file)
            message = current_agent.no_image_prompt(transcript.text)
        elif file[-3:] in ["jpg", "peg", "png", "bmp"]:
            image = Image.open(file)
            # Fall back to a generic instruction if no text came with the image.
            message = current_agent.yes_image_prompt(user_input or "Describe the image.")
        else:
            image = None
            message = current_agent.no_image_prompt("Prompt the user to enter at least one input.")

        inputs = current_agent.tokenizer(
            image,
            message,
            add_special_tokens=False,
            return_tensors="pt",
        ).to("cuda")
        # Run generation on a background thread so the streamer yields tokens as
        # they are produced (the standard TextIteratorStreamer pattern).
        text_streamer = TextIteratorStreamer(current_agent.tokenizer, skip_prompt=True)
        generation_kwargs = dict(inputs, streamer=text_streamer, max_new_tokens=128,
                                 use_cache=True, temperature=1.5, min_p=0.1)
        Thread(target=current_agent.model.generate, kwargs=generation_kwargs).start()
        self.streamer = text_streamer

    def similarity_finder(self, keywords: list, sentence: str):
        # Temporarily index the keywords, find the closest one to `sentence`,
        # then delete the temporary entries by id so the store stays empty.
        ids = similarity_store.add_texts(keywords)
        match = similarity_store.similarity_search(sentence, k=1)
        similarity_store.delete(ids=ids)
        return match[0].page_content
text_agent = unsloth_agent(model=model,
                           tokenizer=tokenizer,
                           sys_prompt="You are a medical assistant. Answer the query in two sentences or less. Also, please put a disclaimer in the end that this does not constitute medical advice.",
                           name="text_agent")
image_agent = unsloth_agent(model=model,
                            tokenizer=tokenizer,
                            sys_prompt="You are a medical assistant. Describe the image in five sentences or less. Also, please put a disclaimer in the end that this does not constitute medical advice.",
                            name="image_agent")
audio_agent = unsloth_agent(model=model,
                            tokenizer=tokenizer,
                            sys_prompt="You are a medical assistant. Answer the query in two sentences or less. Also, please put a disclaimer in the end that this does not constitute medical advice.",
                            name="audio_agent")
def gradio_chat(messages, history):  # history is required by ChatInterface but unused
    files, text = messages["files"], messages["text"]
    if len(files) == 0 and len(text) == 0:
        return "Please enter a valid input"
    elif len(files) == 0 or files[0][-3:] in ["mp3", "wav"]:
        output_text, input_prompt, input_file = "", "", None
        if len(files) == 0:
            # Plain text: augment the prompt with retrieved MedQuAD context.
            context = " ".join(doc.page_content for doc in retriever.invoke(text))
            input_prompt = f"{text}, context : {context}"
        elif len(text) == 0:
            input_prompt = None
            input_file = files[0]
        else:
            context = " ".join(doc.page_content for doc in retriever.invoke(text))
            input_prompt = f"{text}, context : {context}"
            input_file = files[0]
        supervisor_agent = unsloth_agent_supervisor([text_agent, image_agent, audio_agent], input_prompt, input_file)
        for chat in supervisor_agent.streamer:
            output_text += chat
        return output_text
    elif len(text) == 0 or files[0][-3:] in ["jpg", "peg", "png", "bmp"]:
        output_text, final_text = "", ""
        # First pass: describe the image.
        supervisor_agent = unsloth_agent_supervisor([text_agent, image_agent, audio_agent], text, file=files[0])
        for chat in supervisor_agent.streamer:
            output_text += chat
        # Second pass: retrieve context for the description and answer again.
        context = retriever.invoke(output_text)
        context_text = context[0].page_content if context else ""
        final_supervisor = unsloth_agent_supervisor([text_agent, image_agent, audio_agent], f"{output_text} context={context_text}")
        for final_chat in final_supervisor.streamer:
            final_text += final_chat
        return final_text
    else:
        return "Invalid Input"
app = gr.ChatInterface(fn=gradio_chat, type="messages", title="Medical Assistant", multimodal=True)
if __name__ == "__main__":
app.launch()