# medical/app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Load model and tokenizer with optimizations
model_name = "Amir230703/phi3-medmcqa-finetuned"
tokenizer = AutoTokenizer.from_pretrained(model_name)
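# Defensive fallback: the padding/generation settings below need a pad token.
# Reuse EOS if this checkpoint's tokenizer does not define one (an assumption,
# not confirmed for this particular fine-tune).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token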
# Prefer FlashAttention 2 on a CUDA GPU (requires the flash-attn package);
# fall back to PyTorch's built-in SDPA kernels and fp32 on CPU, where
# float16 matmuls are unsupported and flash-attn is unavailable.
use_cuda = torch.cuda.is_available()
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if use_cuda else torch.float32,
    device_map="auto",
    attn_implementation="flash_attention_2" if use_cuda else "sdpa",
).eval()
# Compile for faster kernels on GPU (the first call incurs a one-time tracing cost)
if torch.cuda.is_available():
    model = torch.compile(model)
def generate_answer(question):
    # Create a structured prompt
    prompt = f"""Instruction: Answer the following medical question concisely.
Question: {question}
Answer:"""
    # Tokenize, truncating overly long questions to 512 tokens
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(model.device)
    # Generate with sampling; cap new tokens to keep latency reasonable
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.1,  # Discourage repeated phrases
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,  # Avoid the missing-pad-token warning
        )
    # Decode and return only the text after the final "Answer:" marker
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer.split("Answer:")[-1].strip()
# Gradio interface with queueing
demo = gr.Interface(
    fn=generate_answer,
    inputs=gr.Textbox(placeholder="Enter your medical question...", lines=3),
    outputs=gr.Textbox(label="Answer"),
    title="Medical QA Assistant",
    description="AI-powered medical question answering. Please be specific in your queries.",
    allow_flagging="never",
)
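# Enable the queue referenced above so concurrent requests are serialized
# through the single model instance instead of overlapping generate() calls
demo.queue()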
# Launch with performance settings
demo.launch(
    server_name="0.0.0.0" if torch.cuda.is_available() else None,
    share=False,
    max_threads=2,
)