import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from unsloth import FastLanguageModel
from peft import PeftModel
import logging
import os


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


model = None
tokenizer = None


|
def load_model():
    """Load the model and tokenizer using the Unsloth 4-bit path."""
    global model, tokenizer

    try:
        logger.info("Loading model with Unsloth...")

        # Load the 4-bit quantized base model first so there is always a
        # working model to fall back on if the merged checkpoint fails.
        base_model_name = "unsloth/mistral-7b-bnb-4bit"

        logger.info(f"Loading base model: {base_model_name}")
        base_model, base_tokenizer = FastLanguageModel.from_pretrained(
            model_name=base_model_name,
            max_seq_length=2048,
            dtype=None,
            load_in_4bit=True,
        )

        logger.info("Base model loaded, now trying to load the merged model...")

        try:
            logger.info("Loading the merged fine-tuned model...")
            model = AutoModelForCausalLM.from_pretrained(
                "ndhaliwal59/mistral-7b-bloodwork-analysis-merged",
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
            tokenizer = AutoTokenizer.from_pretrained(
                "ndhaliwal59/mistral-7b-bloodwork-analysis-merged"
            )
            logger.info("Merged model loaded successfully!")

        except Exception as merged_error:
            logger.error(f"Failed to load merged model: {merged_error}")
            logger.info("Using base model instead...")
            model = base_model
            tokenizer = base_tokenizer

        # Mistral tokenizers ship without a pad token; reuse EOS for padding.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info("Model setup completed successfully!")
        return True

    except Exception as e:
        logger.error(f"Error loading model: {e}")
        logger.error("Trying fallback loading method...")

        # Last resort: a small model so the interface stays usable for testing.
        try:
            logger.info("Loading fallback model for testing...")
            model = AutoModelForCausalLM.from_pretrained(
                "microsoft/DialoGPT-medium",
                torch_dtype=torch.float16,
                device_map="auto",
            )
            tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")

            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            logger.info("Fallback model loaded for testing!")
            return True

        except Exception as fallback_error:
            logger.error(f"Fallback loading also failed: {fallback_error}")
            return False

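# Note: PeftModel is imported above but not used in the loading path. If the
# fine-tuned weights were published as a LoRA adapter rather than a merged
# checkpoint, they could be attached to the 4-bit base model instead. A
# minimal sketch (the adapter repo name below is hypothetical, not a real
# published artifact):
#
#     adapter_model = PeftModel.from_pretrained(
#         base_model,
#         "ndhaliwal59/mistral-7b-bloodwork-analysis-lora",  # hypothetical repo
#     )
#
# This would keep memory usage at the 4-bit level while still applying the
# fine-tuned behaviour.
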
|
def generate_response(prompt, max_length=256, temperature=0.7, top_p=0.9):
    """Generate a response from the model."""
    global model, tokenizer

    # Lazily load the model the first time a request comes in.
    if model is None or tokenizer is None:
        logger.info("Model not loaded, attempting to load...")
        if not load_model():
            return "❌ Error: Failed to load model. Please check the logs for details."

    try:
        # Switch Unsloth models into their optimized inference mode; plain
        # transformers models will raise here, which is harmless.
        try:
            FastLanguageModel.for_inference(model)
        except Exception as e:
            logger.info(f"Not an Unsloth model or inference mode failed: {e}")

        # Mistral instruct format. The tokenizer adds the BOS token itself,
        # so <s> is not prepended here to avoid a duplicate BOS.
        formatted_prompt = f"[INST] {prompt} [/INST]"

        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False,
        )

        # Move the inputs to whichever device the model was placed on.
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        logger.info(f"Input shape: {inputs['input_ids'].shape}")
        logger.info(f"Model device: {device}")

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
            )

        # Decode only the newly generated tokens, skipping the prompt.
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        return response.strip()

    except Exception as e:
        logger.error(f"Generation error: {e}")
        import traceback
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return f"❌ Error during generation: {str(e)}"

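# Quick sanity check outside the UI (a minimal sketch; run it manually in a
# Python shell rather than at import time, since it triggers model loading):
#
#     print(generate_response("What does an elevated white blood cell count indicate?"))
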
|
def get_model_info():
    """Get information about the loaded model."""
    global model, tokenizer

    if model is None:
        return "No model loaded"

    try:
        model_type = type(model).__name__
        device = str(next(model.parameters()).device)
        return f"Model: {model_type}, Device: {device}"
    except Exception as e:
        return f"Model info error: {e}"

|
|
|
with gr.Blocks(title="🩺 Mistral-7B Bloodwork Analysis API") as demo:

    gr.Markdown("""
    # 🩺 Mistral-7B Bloodwork Analysis API

    AI assistant for analyzing bloodwork results and medical data.
    This interface also provides API access at `/api/predict`.
    """)

    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Medical Query",
                placeholder="Describe the bloodwork results you'd like analyzed...",
                lines=4,
                max_lines=10,
            )

            with gr.Row():
                max_length_slider = gr.Slider(
                    minimum=50,
                    maximum=2500,
                    value=256,
                    step=10,
                    label="Max Response Length",
                )

                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )

                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top-p",
                )

            with gr.Row():
                submit_btn = gr.Button("🔬 Analyze", variant="primary")
                clear_btn = gr.Button("🗑️ Clear")

        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="Analysis Result",
                lines=8,
                max_lines=15,
                show_copy_button=True,
            )

            model_info = gr.Textbox(
                label="Model Status",
                value="Loading...",
                interactive=False,
            )

    with gr.Row():
        gr.Examples(
            examples=[
                ["What does an elevated white blood cell count indicate?", 200, 0.7, 0.9],
                ["Explain the significance of high cholesterol levels", 200, 0.7, 0.9],
                ["What are normal ranges for blood glucose?", 200, 0.7, 0.9],
                ["Patient has WBC: 15,000, RBC: 4.2, Hemoglobin: 12.5. What could this indicate?", 300, 0.7, 0.9],
                ["Interpret these liver function tests: ALT: 65, AST: 58, Bilirubin: 2.1", 250, 0.7, 0.9],
            ],
            inputs=[prompt_input, max_length_slider, temperature_slider, top_p_slider],
            outputs=output_text,
            fn=generate_response,
            cache_examples=False,
        )

    # Wire up the interactions.
    submit_btn.click(
        fn=generate_response,
        inputs=[prompt_input, max_length_slider, temperature_slider, top_p_slider],
        outputs=output_text,
    )

    clear_btn.click(
        fn=lambda: ("", ""),
        outputs=[prompt_input, output_text],
    )

    # Show which model actually loaded once the page is opened.
    demo.load(
        fn=get_model_info,
        outputs=model_info,
    )

|
|
|
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
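
# Calling this app programmatically: the header text above advertises API
# access. One option is the gradio_client package; a minimal sketch, with a
# placeholder Space ID and an assumed endpoint name (check the real one with
# client.view_api(), since it depends on the Gradio version):
#
#     from gradio_client import Client
#
#     client = Client("<user>/<space-name>")  # placeholder Space ID
#     result = client.predict(
#         "What are normal ranges for blood glucose?",  # prompt
#         256,   # max response length
#         0.7,   # temperature
#         0.9,   # top-p
#         api_name="/predict",  # assumption; verify via client.view_api()
#     )
#     print(result)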