|
|
|
"""notebook9de6b64a65 (4).ipynb |
|
|
|
Automatically generated by Colab. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1c2KnGGRls-UP7uEOg2uV-FX5EjJkhtte |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from huggingface_hub import snapshot_download |
|
import gradio as gr |
|
import assemblyai as aai |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
from langchain_chroma import Chroma |
|
import chromadb |
|
from unsloth import FastVisionModel |
|
import torch |
|
from langchain_core.vectorstores import InMemoryVectorStore |
|
from transformers import TextIteratorStreamer |
|
from PIL import Image |
|
|
|
|
|
import os

# One-time fetch of the prebuilt MedQuAD embedding database into the local
# Chroma persistence directory (dataset repo, not a model repo).
snapshot_download(repo_id="pranavupadhyaya52/lavita-MedQuAD-embeddings", repo_type="dataset", local_dir="./chroma_langchain_db")

# SECURITY FIX: the AssemblyAI key was hard-coded in plain text. Prefer the
# ASSEMBLYAI_API_KEY environment variable; fall back to the previous literal
# so existing deployments keep working. The committed key should be rotated.
aai.settings.api_key = os.environ.get("ASSEMBLYAI_API_KEY", "c50e769cd99c43509c13bd6226645a2c")

# Speech-to-text configuration used for audio uploads.
config = aai.TranscriptionConfig(speech_model=aai.SpeechModel.best)

# Sentence embedder on GPU; must match the model that produced the downloaded
# MedQuAD vectors (all-mpnet-base-v2), or similarity search is meaningless.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2", model_kwargs={"device": "cuda:0"})

vector_store = Chroma(
    collection_name="mediwiki_rag",
    embedding_function=embeddings,
    persist_directory="./chroma_langchain_db",
)

# NOTE(review): this client/collection pair is never used anywhere below, and
# the collection name is the literal string "collection_name" — looks like
# leftover boilerplate; confirm before deleting.
persistent_client = chromadb.PersistentClient()
collection = persistent_client.get_or_create_collection("collection_name")

# Default retriever over the MedQuAD store, used for RAG context in gradio_chat.
retriever = vector_store.as_retriever(
)
|
|
|
""" |
|
from datasets import load_dataset |
|
from langchain_core.documents import Document |
|
|
|
data = load_dataset("lavita/MedQuAD") |
|
|
|
vector_store.add_documents(documents=[Document(page_content=str(i['answer']),metadata={"source":i['document_source']}, ) for k, i in zip(range(41000), data["train"])]) |
|
""" |
|
|
|
# Load the fine-tuned Llama-3.2 11B vision model in 4-bit quantization so it
# fits on a single GPU; "unsloth" gradient checkpointing trades compute for
# memory. `tokenizer` here is the multimodal processor used for both text and
# image inputs.
model, tokenizer = FastVisionModel.from_pretrained(
    "pranavupadhyaya52/llama-3.2-11b-vision-instruct-mediwiki3",
    load_in_4bit = True,
    use_gradient_checkpointing = "unsloth",
)

# Scratch vector store used by the supervisor to pick an agent by embedding
# similarity between agent names and the request type
# (see unsloth_agent_supervisor.similarity_finder).
similarity_store = InMemoryVectorStore(embeddings)
|
|
|
|
|
class unsloth_agent:
    """A chat agent: a (model, tokenizer) pair bound to a fixed system prompt.

    Builds chat-template strings for text-only and image+text requests; the
    heavy lifting (tokenization to tensors, generation) is done by the
    supervisor, which reads ``self.model`` and ``self.tokenizer``.
    """

    def __init__(self, model, tokenizer, sys_prompt, name):
        self.name = name  # used by the supervisor for similarity routing
        self.model, self.tokenizer = model, tokenizer
        self.sys_prompt = sys_prompt

    def no_image_prompt(self, prompt):
        """Return the chat-templated prompt for a text-only request.

        BUG FIX: previously called the module-global ``tokenizer`` instead of
        ``self.tokenizer``, so an agent constructed with a different tokenizer
        would silently use the wrong one.
        NOTE(review): the system prompt is sent with role "user" (not
        "system") — some vision chat templates reject a system role alongside
        images; confirm this is intentional.
        """
        message = self.tokenizer.apply_chat_template([
            {"role": "user", "content": [
                {"type": "text", "text": self.sys_prompt}
            ]},
            {"role": "user", "content": [
                {"type": "text", "text": prompt}
            ]}
        ], add_generation_prompt = True)
        return message

    def yes_image_prompt(self, prompt):
        """Return the chat-templated prompt for an image + text request.

        The image placeholder comes first in the second turn; the actual
        pixels are supplied later by the supervisor's tokenizer call.
        """
        message = self.tokenizer.apply_chat_template([
            {"role": "user", "content": [
                {"type": "text", "text": self.sys_prompt}
            ]},
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": prompt}
            ]}
        ], add_generation_prompt = True)
        return message
|
|
|
class unsloth_agent_supervisor:
    """Routes one user request (text, audio, or image) to the best-matching
    agent, runs generation, and exposes the token stream as ``self.streamer``.

    NOTE(review): generation runs synchronously inside ``__init__`` — by the
    time a caller iterates ``self.streamer``, all tokens are already queued.
    """

    def __init__(self, agents:list, user_input:str, file=None):
        # Normalize degenerate inputs into instructions for the model rather
        # than failing outright (e.g. no input at all, or multiple files).
        if user_input == None and file == None:
            user_input = "prompt user for input"
        elif type(file).__name__ == "list":
            user_input = "Tell user that only one multimedia object is allowed."
        else:
            pass

        _agents, _agent_name = [str(x.name) for x in agents], ""
        current_agent = None

        # Pick an agent by embedding similarity between the agent names and
        # either the file "extension" (last 3 chars of the path) or "text".
        # NOTE(review): this assumes extensions are always 3 characters.
        if file!= None:
            _agent_name = self.similarity_finder(_agents, file[len(file)-3:])
        else:
            _agent_name = self.similarity_finder(_agents, "text")

        # Linear scan for the agent whose name matched.
        for i in agents:
            if i.name == _agent_name:
                current_agent = i
            else:
                pass

        # Build model inputs. "peg" stands in for ".jpeg" because only the
        # last three characters of the filename are inspected.
        if file == None and user_input != None:
            image = None
            message = current_agent.no_image_prompt(user_input)
        elif str(file[(len(file)-3):]) in ["mp3", "wav"]:
            image = None
            # Audio is transcribed with AssemblyAI, then treated as plain text.
            input_text = aai.Transcriber(config=config).transcribe(file)
            message = current_agent.no_image_prompt(input_text.text)
        elif str(file[(len(file)-3):]) in ["jpg", "peg", "png", "bmp"]:
            image = Image.open(file)
            message = current_agent.yes_image_prompt(user_input)
        else:
            # Unknown file type: ask the model to prompt the user again.
            image = None
            message = current_agent.no_image_prompt("Prompt the user to enter atleast one input.")

        # Multimodal processor call: image may be None for text/audio paths.
        inputs = current_agent.tokenizer(
            image,
            message,
            add_special_tokens = False,
            return_tensors = "pt",
        ).to("cuda")

        # skip_prompt drops the echoed prompt tokens from the output stream.
        text_streamer = TextIteratorStreamer(current_agent.tokenizer, skip_prompt = True)
        _ = current_agent.model.generate(**inputs, streamer=text_streamer, max_new_tokens=128, use_cache=True, temperature=1.5, min_p=0.1)
        self.streamer = text_streamer

    def similarity_finder(self, keywords: list, sentence: str):
        """Return the keyword most similar to *sentence* using the shared
        module-level ``similarity_store``.

        NOTE(review): ``similarity_store.delete()`` is called with no ids —
        verify this actually clears the store for InMemoryVectorStore;
        otherwise keywords accumulate across calls and routing degrades.
        """
        return_keyword = ""
        similarity_store.add_texts(keywords)
        return_keyword = similarity_store.similarity_search(sentence, k=1)
        similarity_store.delete()
        return return_keyword[0].page_content
|
|
|
# Three single-purpose agents sharing one model/tokenizer; the supervisor
# routes between them by embedding similarity on their names.
# NOTE(review): text_agent's prompt is cut off mid-sentence ("does not
# construe medical") — presumably meant "does not constitute medical advice"
# like the other two; confirm and fix the string.
text_agent = unsloth_agent(model=model,
                           tokenizer=tokenizer,
                           sys_prompt="You are a medical assistant. Answer the query in two sentences or less. Also, please put a disclaimer in the end that this does not construe medical", name="text_agent")

image_agent = unsloth_agent(model=model,
                            tokenizer=tokenizer,
                            sys_prompt="You are a medical assistant. Describe the image in five sentences or less. Also, please put a disclaimer in the end that this does not constitute medical advice",name="image_agent")
audio_agent = unsloth_agent(model=model,
                            tokenizer=tokenizer,
                            sys_prompt="You are a medical assistant. Answer the query in two sentences or less. Also, please put a disclaimer in the end that this does not constitute medical advice",name="audio_agent")
|
|
|
def gradio_chat(messages, history):
    """Gradio ChatInterface callback for multimodal medical Q&A.

    ``messages`` is Gradio's multimodal dict with "text" and "files" keys;
    ``history`` is unused but required by the ChatInterface signature.
    Returns the assistant's reply as a plain string.
    """
    if len(messages["files"]) == 0 and len(messages["text"]) == 0:
        return "Please enter a valid input"

    # Text-only request, or an audio attachment (extension sniffed from the
    # last three characters of the file path).
    elif len(messages["files"]) == 0 or messages["files"][0][(len(messages["files"][0])-3):] in ["mp3", "wav"]:
        output_text, input_prompt, input_file = "", "", None
        if len(messages["files"])==0:
            # Augment the question with retrieved MedQuAD context (RAG).
            input_prompt = f"""{messages["text"]}, context : {retriever.invoke(messages["text"])}"""
            input_file = None
        elif len(messages["text"]) == 0:
            # Audio-only: transcription happens inside the supervisor.
            input_prompt = None
            input_file = messages["files"][0]
        else:
            input_prompt = f"""{messages["text"]}, context : {retriever.invoke(messages["text"])}"""
            input_file = messages["files"][0]
        supervisor_agent = unsloth_agent_supervisor([text_agent, image_agent, audio_agent], input_prompt, input_file)
        for chat in supervisor_agent.streamer:
            output_text += chat
        # BUG FIX: removed a `retriever.invoke(output_text)` call here whose
        # result was never used — it cost one full retrieval per request.
        return output_text

    # Image attachment: describe the image first, then run a second,
    # context-grounded generation pass over that description.
    elif len(messages["text"]) == 0 or messages["files"][0][(len(messages["files"][0])-3):] in ["jpg", "peg", "png", "bmp"]:
        output_text, final_text = "", ""
        supervisor_agent = unsloth_agent_supervisor([text_agent, image_agent, audio_agent], messages["text"], file=messages["files"][0])
        for chat in supervisor_agent.streamer:
            output_text += chat
        # Ground the image description in the MedQuAD store before answering.
        context = retriever.invoke(output_text)
        final_supervisor = unsloth_agent_supervisor([text_agent, image_agent, audio_agent], f"{output_text} context={context[0].page_content}")
        for final_chat in final_supervisor.streamer:
            final_text += final_chat
        return final_text

    else:
        return "Invalid Input"
|
|
|
|
|
# Multimodal chat UI; type="messages" selects the OpenAI-style message format,
# and multimodal=True enables the file-upload textbox that feeds gradio_chat.
app = gr.ChatInterface(fn=gradio_chat, type="messages", title="Medical Assistant", multimodal=True)
# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    app.launch()