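"""
Gradio app for bloodwork analysis with a fine-tuned Mistral-7B model.

Loads the merged model ndhaliwal59/mistral-7b-bloodwork-analysis-merged
(falling back to the unsloth/mistral-7b-bnb-4bit base model, or to
microsoft/DialoGPT-medium if that also fails) and exposes a simple
Gradio UI plus API access for analyzing bloodwork-related queries.
"""
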
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from unsloth import FastLanguageModel
from peft import PeftModel
import logging
import os
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global variables for model and tokenizer
model = None
tokenizer = None

def load_model():
    """Load the model and tokenizer using the Unsloth method."""
    global model, tokenizer
    try:
        logger.info("Loading model with Unsloth...")

        # Load the base model that was used for training
        base_model_name = "unsloth/mistral-7b-bnb-4bit"
        logger.info(f"Loading base model: {base_model_name}")
        base_model, base_tokenizer = FastLanguageModel.from_pretrained(
            model_name=base_model_name,
            max_seq_length=2048,
            dtype=None,
            load_in_4bit=True,
        )
        logger.info("Base model loaded, now trying to load the merged model...")

        # The fine-tuned model is merged, so try to load it directly
        try:
            logger.info("Loading the merged model...")
            model = AutoModelForCausalLM.from_pretrained(
                "ndhaliwal59/mistral-7b-bloodwork-analysis-merged",
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
            tokenizer = AutoTokenizer.from_pretrained("ndhaliwal59/mistral-7b-bloodwork-analysis-merged")
            logger.info("Merged model loaded successfully!")
        except Exception as merged_error:
            logger.error(f"Failed to load merged model: {merged_error}")
            logger.info("Using base model instead...")
            model = base_model
            tokenizer = base_tokenizer

        # Make sure the tokenizer has a pad token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info("Model setup completed successfully!")
        return True

    except Exception as e:
        logger.error(f"Error loading model: {e}")
        logger.error("Trying fallback loading method...")

        # Fallback: load a small model with standard transformers so the UI still works
        try:
            logger.info("Loading fallback model for testing...")
            model = AutoModelForCausalLM.from_pretrained(
                "microsoft/DialoGPT-medium",
                torch_dtype=torch.float16,
                device_map="auto"
            )
            tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            logger.info("Fallback model loaded for testing!")
            return True
        except Exception as fallback_error:
            logger.error(f"Fallback loading also failed: {fallback_error}")
            return False

def generate_response(prompt, max_length=256, temperature=0.7, top_p=0.9):
    """Generate a response from the model."""
    global model, tokenizer

    # Load the model if it is not already loaded
    if model is None or tokenizer is None:
        logger.info("Model not loaded, attempting to load...")
        if not load_model():
            return "❌ Error: Failed to load model. Please check the logs for details."

    try:
        # Put the model into inference mode if it is an Unsloth model
        try:
            FastLanguageModel.for_inference(model)
        except Exception as e:
            # Not an Unsloth model (or inference mode failed); continue normally
            logger.info(f"Not an Unsloth model or inference mode failed: {e}")

        # Format the prompt with the Mistral instruction template
        formatted_prompt = f"<s>[INST] {prompt} [/INST]"

        # Tokenize the input
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False
        )

        # Move inputs to the same device as the model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        logger.info(f"Input shape: {inputs['input_ids'].shape}")
        logger.info(f"Model device: {device}")

        # Generate the response
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask', None),
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1
            )

        # Decode only the newly generated tokens (skip the input prompt)
        response = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        )
        return response.strip()

    except Exception as e:
        logger.error(f"Generation error: {e}")
        import traceback
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return f"❌ Error during generation: {str(e)}"

def get_model_info():
    """Get information about the loaded model."""
    global model, tokenizer
    if model is None:
        return "No model loaded"
    try:
        model_type = type(model).__name__
        device = str(next(model.parameters()).device)
        return f"Model: {model_type}, Device: {device}"
    except Exception as e:
        return f"Model info error: {e}"

# Create Gradio interface
with gr.Blocks(title="🩺 Mistral-7B Bloodwork Analysis API") as demo:
    gr.Markdown("""
    # 🩺 Mistral-7B Bloodwork Analysis API

    AI assistant for analyzing bloodwork results and medical data.

    This interface also provides API access at `/api/predict`.
    """)

    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Medical Query",
                placeholder="Describe the bloodwork results you'd like analyzed...",
                lines=4,
                max_lines=10
            )

            with gr.Row():
                max_length_slider = gr.Slider(
                    minimum=50,
                    maximum=2500,
                    value=256,
                    step=10,
                    label="Max Response Length"
                )
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top-p"
                )

            with gr.Row():
                submit_btn = gr.Button("🔬 Analyze", variant="primary")
                clear_btn = gr.Button("🗑️ Clear")

        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="Analysis Result",
                lines=8,
                max_lines=15,
                show_copy_button=True
            )
            model_info = gr.Textbox(
                label="Model Status",
                value="Loading...",
                interactive=False
            )

    with gr.Row():
        gr.Examples(
            examples=[
                ["What does an elevated white blood cell count indicate?", 200, 0.7, 0.9],
                ["Explain the significance of high cholesterol levels", 200, 0.7, 0.9],
                ["What are normal ranges for blood glucose?", 200, 0.7, 0.9],
                ["Patient has WBC: 15,000, RBC: 4.2, Hemoglobin: 12.5. What could this indicate?", 300, 0.7, 0.9],
                ["Interpret these liver function tests: ALT: 65, AST: 58, Bilirubin: 2.1", 250, 0.7, 0.9]
            ],
            inputs=[prompt_input, max_length_slider, temperature_slider, top_p_slider],
            outputs=output_text,
            fn=generate_response,
            cache_examples=False
        )

    # Event handlers
    submit_btn.click(
        fn=generate_response,
        inputs=[prompt_input, max_length_slider, temperature_slider, top_p_slider],
        outputs=output_text
    )
    clear_btn.click(
        fn=lambda: ("", ""),
        outputs=[prompt_input, output_text]
    )

    # Update model info when the interface loads
    demo.load(
        fn=get_model_info,
        outputs=model_info
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )