""" Gradio app that: - Uses a local model if torch is installed, - Otherwise tries Hugging Face InferenceClient, - Otherwise falls back to legacy InferenceApi with task="text-generation". Make sure HF_TOKEN is set in Space secrets if your model is private. """ import os from typing import Optional import gradio as gr import torch MODEL_ID = "marvinisjarvis/radio_model" HF_TOKEN = os.environ.get("HF_TOKEN", None) # Flags & clients LOCAL_AVAILABLE = False INFERENCE_CLIENT_AVAILABLE = False INFERENCE_API_AVAILABLE = False # Attempt local loading (torch + transformers) try: import torch from transformers import AutoTokenizer, AutoModelForCausalLM device = "cuda" if torch.cuda.is_available() else "cpu" def try_load_local(): print("Attempting to load local model...") tokenizer = AutoTokenizer.from_pretrained( MODEL_ID, trust_remote_code=True, use_fast=True, token=HF_TOKEN ) kwargs = {"trust_remote_code": True, "token": HF_TOKEN, "low_cpu_mem_usage": True} if device == "cuda": kwargs.update({"device_map": "auto", "torch_dtype": torch.float16}) model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs) return tokenizer, model try: tokenizer, model = try_load_local() LOCAL_AVAILABLE = True print("Local model loaded.") except Exception as e: print("Local model load failed:", e) LOCAL_AVAILABLE = False except Exception as e: print("Torch not available or failed to import:", e) LOCAL_AVAILABLE = False # Attempt to use InferenceClient (preferred) if not LOCAL_AVAILABLE: try: from huggingface_hub import InferenceClient client = InferenceClient(token=HF_TOKEN) INFERENCE_CLIENT_AVAILABLE = True print("InferenceClient available - will use remote text-generation via InferenceClient.") except Exception as e: print("InferenceClient not available:", e) INFERENCE_CLIENT_AVAILABLE = False # Fallback to legacy InferenceApi with explicit task inference_api = None if (not LOCAL_AVAILABLE) and (not INFERENCE_CLIENT_AVAILABLE): try: from huggingface_hub import InferenceApi # Explicitly specify task to avoid "Task not specified" errors inference_api = InferenceApi(repo_id=MODEL_ID, token=HF_TOKEN, task="text-generation") INFERENCE_API_AVAILABLE = True print("Using legacy InferenceApi with task='text-generation'.") except Exception as e: print("Hugging Face InferenceApi not available or failed:", e) INFERENCE_API_AVAILABLE = False # Generation wrapper that handles all three paths def generate_answer( prompt: str, max_new_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.9, num_beams: int = 1, stop_token: Optional[str] = None, ) -> str: if not prompt or prompt.strip() == "": return "Please enter a prompt." 

    # Local path
    if LOCAL_AVAILABLE:
        try:
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
            device0 = next(model.parameters()).device
            input_ids = inputs["input_ids"].to(device0)
            attention_mask = inputs.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(device0)
            gen_kwargs = dict(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=int(max_new_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
                num_beams=int(num_beams),
                eos_token_id=getattr(tokenizer, "eos_token_id", None),
                pad_token_id=getattr(tokenizer, "pad_token_id", None),
                do_sample=(float(temperature) > 0) and (int(num_beams) == 1),
            )
            outputs = model.generate(**gen_kwargs)
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
            result = decoded[len(prompt):].strip() if decoded.startswith(prompt) else decoded.strip()
            if stop_token:
                idx = result.find(stop_token)
                if idx != -1:
                    result = result[:idx].strip()
            return result
        except Exception as e:
            print("Local generation failed, falling back to remote. Error:", e)

    # InferenceClient path (preferred remote)
    if INFERENCE_CLIENT_AVAILABLE:
        try:
            # text_generation returns the generated text as a string. The remote
            # text-generation API does not support beam search, so num_beams is only
            # honored on the local path; it also rejects temperature=0 and top_p=1.0,
            # so those parameters are omitted at the slider extremes.
            temp, tp = float(temperature), float(top_p)
            response = client.text_generation(
                prompt=prompt,
                model=MODEL_ID,
                max_new_tokens=int(max_new_tokens),
                temperature=temp if temp > 0 else None,
                top_p=tp if 0 < tp < 1 else None,
            )
            out = response
            if out.startswith(prompt):
                out = out[len(prompt):].strip()
            if stop_token:
                idx = out.find(stop_token)
                if idx != -1:
                    out = out[:idx].strip()
            return out.strip()
        except Exception as e:
            print("InferenceClient call failed, will try legacy InferenceApi. Error:", e)

    # Legacy InferenceApi fallback (explicit task)
    if INFERENCE_API_AVAILABLE and inference_api is not None:
        try:
            params = {
                "max_new_tokens": int(max_new_tokens),
                "temperature": float(temperature),
                "top_p": float(top_p),
            }
            res = inference_api(prompt, params=params)
            # Normalize the response into a plain string
            if isinstance(res, str):
                out = res
            elif isinstance(res, dict) and "generated_text" in res:
                out = res["generated_text"]
            elif isinstance(res, list) and res and isinstance(res[0], dict) and "generated_text" in res[0]:
                out = res[0]["generated_text"]
            else:
                out = str(res)
            if out.startswith(prompt):
                out = out[len(prompt):].strip()
            if stop_token:
                idx = out.find(stop_token)
                if idx != -1:
                    out = out[:idx].strip()
            return out.strip()
        except Exception as e:
            print("Legacy InferenceApi failed:", e)
            return f"Remote inference failed: {e}"

    return (
        "No inference path available. Install torch for local inference or ensure "
        "HF_TOKEN is set and huggingface_hub supports InferenceClient/InferenceApi."
    )
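

# Optional smoke test: a minimal sketch, not part of the original app, for exercising
# generate_answer without launching the UI. It runs only when the (hypothetical)
# SMOKE_TEST environment variable is set, so normal Space startup is unaffected.
if os.environ.get("SMOKE_TEST"):
    print(generate_answer("What does an X-ray of pneumonia typically show?", max_new_tokens=64))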


# --- Gradio UI ---
title = "RadioModel — Radiology Q&A (Mistral 7B fine-tuned)"
description = """
Demo for marvinisjarvis/radio_model. Tries local inference first; otherwise uses
Hugging Face remote inference. If your model is private, add HF_TOKEN in the Space
secrets. Not for clinical use.
"""

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(label="Enter your radiology question", lines=6)
            submit = gr.Button("Generate Answer")
            examples = gr.Examples(
                examples=[
                    "What does an X-ray of pneumonia typically show?",
                    "How can you differentiate a benign lung nodule from a malignant one on CT?",
                    "What are common signs of bone fracture on X-rays?",
                    "Which imaging modality is best for detecting small brain tumors?",
                ],
                inputs=prompt_input,
            )
        with gr.Column(scale=2):
            max_tokens = gr.Slider(32, 1024, value=256, step=32, label="Max New Tokens")
            temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
            num_beams = gr.Slider(1, 5, value=1, step=1, label="Num Beams")
            stop_token = gr.Textbox(label="Optional stop token", placeholder="e.g., ### or ", lines=1)

    output = gr.Textbox(label="Model output", lines=14)

    def on_submit(prompt, max_new_tokens, temperature, top_p, num_beams, stop_token):
        return generate_answer(prompt, max_new_tokens, temperature, top_p, int(num_beams), stop_token)

    submit.click(
        on_submit,
        inputs=[prompt_input, max_tokens, temperature, top_p, num_beams, stop_token],
        outputs=output,
    )

    gr.Markdown("Disclaimer: This demo is for evaluation and research. It is not a medical device.")

if __name__ == "__main__":
    demo.launch()
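
# Dependency note (an assumption, not part of the original file): the imports above
# imply a Space requirements.txt along these lines; torch and transformers are only
# needed if you want the local inference path.
#
#   gradio
#   huggingface_hub
#   transformers
#   torch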