import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from unsloth import FastLanguageModel
from peft import PeftModel
import logging
import os

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for the model and tokenizer
model = None
tokenizer = None


def load_model():
    """Load the model and tokenizer using the Unsloth method."""
    global model, tokenizer

    try:
        logger.info("Loading model with Unsloth...")

        # Load the base model that was used for fine-tuning
        base_model_name = "unsloth/mistral-7b-bnb-4bit"
        logger.info(f"Loading base model: {base_model_name}")

        base_model, base_tokenizer = FastLanguageModel.from_pretrained(
            model_name=base_model_name,
            max_seq_length=2048,
            dtype=None,
            load_in_4bit=True,
        )

        logger.info("Base model loaded, now trying to load the merged model...")

        # The fine-tuned model is merged, so try to load it directly
        try:
            logger.info("Loading the merged model...")
            model = AutoModelForCausalLM.from_pretrained(
                "ndhaliwal59/mistral-7b-bloodwork-analysis-merged",
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
            tokenizer = AutoTokenizer.from_pretrained(
                "ndhaliwal59/mistral-7b-bloodwork-analysis-merged"
            )
            logger.info("Merged model loaded successfully!")
        except Exception as merged_error:
            logger.error(f"Failed to load merged model: {merged_error}")
            logger.info("Using base model instead...")
            model = base_model
            tokenizer = base_tokenizer

        # Make sure the tokenizer has a pad token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info("Model setup completed successfully!")
        return True

    except Exception as e:
        logger.error(f"Error loading model: {e}")
        logger.error("Trying fallback loading method...")

        # Fallback: try loading a small model with standard transformers
        try:
            logger.info("Loading fallback model for testing...")
            model = AutoModelForCausalLM.from_pretrained(
                "microsoft/DialoGPT-medium",
                torch_dtype=torch.float16,
                device_map="auto",
            )
            tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            logger.info("Fallback model loaded for testing!")
            return True
        except Exception as fallback_error:
            logger.error(f"Fallback loading also failed: {fallback_error}")
            return False


def generate_response(prompt, max_length=256, temperature=0.7, top_p=0.9):
    """Generate a response from the model."""
    global model, tokenizer

    # Lazily load the model on first use
    if model is None or tokenizer is None:
        logger.info("Model not loaded, attempting to load...")
        if not load_model():
            return "❌ Error: Failed to load model. Please check the logs for details."
    try:
        # Put the model into inference mode if it is an Unsloth model
        try:
            FastLanguageModel.for_inference(model)
        except Exception as e:
            # Not an Unsloth model (or inference mode failed); continue normally
            logger.info(f"Not an Unsloth model or inference mode failed: {e}")

        # Format the prompt with the Mistral instruction template
        formatted_prompt = f"[INST] {prompt} [/INST]"

        # Tokenize the input; no padding needed for a single prompt
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False,
        )

        # Move inputs to the same device as the model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        logger.info(f"Input shape: {inputs['input_ids'].shape}")
        logger.info(f"Model device: {device}")

        # Generate the response
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask", None),
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
            )

        # Decode only the newly generated tokens (skip the input prompt)
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        return response.strip()

    except Exception as e:
        logger.error(f"Generation error: {e}")
        import traceback
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return f"❌ Error during generation: {str(e)}"


def get_model_info():
    """Get information about the loaded model."""
    global model, tokenizer

    if model is None:
        return "No model loaded"

    try:
        model_type = type(model).__name__
        device = str(next(model.parameters()).device)
        return f"Model: {model_type}, Device: {device}"
    except Exception as e:
        return f"Model info error: {e}"


# Build the Gradio interface
with gr.Blocks(title="🩺 Mistral-7B Bloodwork Analysis API") as demo:
    gr.Markdown("""
    # 🩺 Mistral-7B Bloodwork Analysis API

    AI assistant for analyzing bloodwork results and medical data.
    This interface also provides API access at `/api/predict`.
    """)

    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Medical Query",
                placeholder="Describe the bloodwork results you'd like analyzed...",
                lines=4,
                max_lines=10,
            )

            with gr.Row():
                max_length_slider = gr.Slider(
                    minimum=50,
                    maximum=2500,
                    value=256,
                    step=10,
                    label="Max Response Length",
                )
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top-p",
                )

            with gr.Row():
                submit_btn = gr.Button("🔬 Analyze", variant="primary")
                clear_btn = gr.Button("🗑️ Clear")

        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="Analysis Result",
                lines=8,
                max_lines=15,
                show_copy_button=True,
            )
            model_info = gr.Textbox(
                label="Model Status",
                value="Loading...",
                interactive=False,
            )

    with gr.Row():
        gr.Examples(
            examples=[
                ["What does an elevated white blood cell count indicate?", 200, 0.7, 0.9],
                ["Explain the significance of high cholesterol levels", 200, 0.7, 0.9],
                ["What are normal ranges for blood glucose?", 200, 0.7, 0.9],
                ["Patient has WBC: 15,000, RBC: 4.2, Hemoglobin: 12.5. What could this indicate?", 300, 0.7, 0.9],
                ["Interpret these liver function tests: ALT: 65, AST: 58, Bilirubin: 2.1", 250, 0.7, 0.9],
            ],
            inputs=[prompt_input, max_length_slider, temperature_slider, top_p_slider],
            outputs=output_text,
            fn=generate_response,
            cache_examples=False,
        )

    # Event handlers
    submit_btn.click(
        fn=generate_response,
        inputs=[prompt_input, max_length_slider, temperature_slider, top_p_slider],
        outputs=output_text,
    )

    clear_btn.click(
        fn=lambda: ("", ""),
        outputs=[prompt_input, output_text],
    )

    # Update the model status box when the interface loads
    demo.load(
        fn=get_model_info,
        outputs=model_info,
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
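
# A minimal client sketch for the API access mentioned in the interface header.
# This is an assumption-laden example: it assumes the gradio_client package is
# installed, the app is running locally on port 7860, and that the
# generate_response endpoint is exposed under its default name (the exact route
# varies by Gradio version; check the "Use via API" link in the running UI).
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860")
#   result = client.predict(
#       "What does an elevated white blood cell count indicate?",  # prompt
#       256,    # max response length
#       0.7,    # temperature
#       0.9,    # top-p
#       api_name="/generate_response",  # assumed default; verify in your UI
#   )
#   print(result)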