# radiotest / app.py
"""
Gradio app that:
- Uses a local model if torch is installed,
- Otherwise tries Hugging Face InferenceClient,
- Otherwise falls back to legacy InferenceApi with task="text-generation".
Make sure HF_TOKEN is set in Space secrets if your model is private.
"""
import os
from typing import Optional
import gradio as gr
MODEL_ID = "marvinisjarvis/radio_model"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Flags & clients
LOCAL_AVAILABLE = False
INFERENCE_CLIENT_AVAILABLE = False
INFERENCE_API_AVAILABLE = False
# Attempt local loading (torch + transformers)
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def try_load_local():
        print("Attempting to load local model...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID, trust_remote_code=True, use_fast=True, token=HF_TOKEN
        )
        kwargs = {"trust_remote_code": True, "token": HF_TOKEN, "low_cpu_mem_usage": True}
        if device == "cuda":
            kwargs.update({"device_map": "auto", "torch_dtype": torch.float16})
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
        return tokenizer, model

    try:
        tokenizer, model = try_load_local()
        LOCAL_AVAILABLE = True
        print("Local model loaded.")
    except Exception as e:
        print("Local model load failed:", e)
        LOCAL_AVAILABLE = False
except Exception as e:
    print("Torch not available or failed to import:", e)
    LOCAL_AVAILABLE = False
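# Rough sizing note (an assumption based on the 7B model named in the UI title below):
# a 7B-parameter checkpoint needs roughly 14 GB of memory in float16 and about twice that
# in float32, so local loading is only realistic on sufficiently large hardware; if it
# fails, the remote fallbacks below take over.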
# Attempt to use InferenceClient (preferred)
if not LOCAL_AVAILABLE:
    try:
        from huggingface_hub import InferenceClient

        client = InferenceClient(token=HF_TOKEN)
        INFERENCE_CLIENT_AVAILABLE = True
        print("InferenceClient available - will use remote text-generation via InferenceClient.")
    except Exception as e:
        print("InferenceClient not available:", e)
        INFERENCE_CLIENT_AVAILABLE = False
# Fallback to legacy InferenceApi with explicit task
inference_api = None
if (not LOCAL_AVAILABLE) and (not INFERENCE_CLIENT_AVAILABLE):
    try:
        from huggingface_hub import InferenceApi

        # Explicitly specify task to avoid "Task not specified" errors
        inference_api = InferenceApi(repo_id=MODEL_ID, token=HF_TOKEN, task="text-generation")
        INFERENCE_API_AVAILABLE = True
        print("Using legacy InferenceApi with task='text-generation'.")
    except Exception as e:
        print("Hugging Face InferenceApi not available or failed:", e)
        INFERENCE_API_AVAILABLE = False
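# Not part of the original app: a small helper sketch that reports which of the three
# paths ended up active, which can make the Space logs easier to read.
def active_backend() -> str:
    if LOCAL_AVAILABLE:
        return "local (torch + transformers)"
    if INFERENCE_CLIENT_AVAILABLE:
        return "remote (InferenceClient)"
    if INFERENCE_API_AVAILABLE:
        return "remote (legacy InferenceApi)"
    return "none"

print("Active inference backend:", active_backend())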
# Generation wrapper that handles all three paths
def generate_answer(
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    num_beams: int = 1,
    stop_token: Optional[str] = None,
) -> str:
    if not prompt or prompt.strip() == "":
        return "Please enter a prompt."
    # Local path
    if LOCAL_AVAILABLE:
        try:
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
            # With device_map="auto" the model may not sit on the default device,
            # so move the inputs to wherever the model parameters live.
            device0 = next(model.parameters()).device
            input_ids = inputs["input_ids"].to(device0)
            attention_mask = inputs.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(device0)
            gen_kwargs = dict(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=int(max_new_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
                num_beams=int(num_beams),
                eos_token_id=getattr(tokenizer, "eos_token_id", None),
                pad_token_id=getattr(tokenizer, "pad_token_id", None),
                # Sample only when temperature > 0 and beam search is disabled
                do_sample=(float(temperature) > 0) and (int(num_beams) == 1),
            )
            outputs = model.generate(**gen_kwargs)
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Drop the echoed prompt so only the completion is returned
            result = decoded[len(prompt):].strip() if decoded.startswith(prompt) else decoded.strip()
            if stop_token:
                idx = result.find(stop_token)
                if idx != -1:
                    result = result[:idx].strip()
            return result
        except Exception as e:
            print("Local generation failed, falling back to remote. Error:", e)
    # InferenceClient path (preferred remote)
    if INFERENCE_CLIENT_AVAILABLE:
        try:
            # InferenceClient.text_generation has no beam-search parameter, so num_beams
            # only applies to the local path; forward the sampling parameters it accepts.
            response = client.text_generation(
                prompt=prompt,
                model=MODEL_ID,
                max_new_tokens=int(max_new_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
            )
            # The response is the generated text string
            out = response
            if out.startswith(prompt):
                out = out[len(prompt):].strip()
            if stop_token:
                idx = out.find(stop_token)
                if idx != -1:
                    out = out[:idx].strip()
            return out.strip()
        except Exception as e:
            print("InferenceClient call failed, will try legacy InferenceApi. Error:", e)
    # Legacy InferenceApi fallback (explicit task)
    if INFERENCE_API_AVAILABLE and inference_api is not None:
        try:
            params = {"max_new_tokens": int(max_new_tokens), "temperature": float(temperature), "top_p": float(top_p)}
            res = inference_api(prompt, params=params)
            # normalize response
            if isinstance(res, str):
                out = res
            elif isinstance(res, dict) and "generated_text" in res:
                out = res["generated_text"]
            elif isinstance(res, list) and res and isinstance(res[0], dict) and "generated_text" in res[0]:
                out = res[0]["generated_text"]
            else:
                out = str(res)
            if out.startswith(prompt):
                out = out[len(prompt):].strip()
            if stop_token:
                idx = out.find(stop_token)
                if idx != -1:
                    out = out[:idx].strip()
            return out.strip()
        except Exception as e:
            print("Legacy InferenceApi failed:", e)
            return f"Remote inference failed: {e}"

    return (
        "No inference path available. Install torch for local inference or ensure HF_TOKEN "
        "is set and huggingface_hub supports InferenceClient/InferenceApi."
    )
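# Hedged usage sketch (not in the original app): a quick smoke test of the wrapper above,
# enabled by setting SMOKE_TEST=1 in the environment. The prompt is illustrative only.
if os.environ.get("SMOKE_TEST") == "1":
    print(generate_answer("What does an X-ray of pneumonia typically show?", max_new_tokens=64))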
# --- Gradio UI ---
title = "RadioModel — Radiology Q&A (Mistral 7B fine-tuned)"
description = """
Demo for marvinisjarvis/radio_model.
Tries local inference first; otherwise uses Hugging Face remote inference.
If your model is private, add HF_TOKEN in Space secrets. Not for clinical use.
"""
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(label="Enter your radiology question", lines=6)
            submit = gr.Button("Generate Answer")
            examples = gr.Examples(
                examples=[
                    "What does an X-ray of pneumonia typically show?",
                    "How can you differentiate a benign lung nodule from a malignant one on CT?",
                    "What are common signs of bone fracture on X-rays?",
                    "Which imaging modality is best for detecting small brain tumors?",
                ],
                inputs=prompt_input,
            )
        with gr.Column(scale=2):
            max_tokens = gr.Slider(32, 1024, value=256, step=32, label="Max New Tokens")
            temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
            num_beams = gr.Slider(1, 5, value=1, step=1, label="Num Beams")
            stop_token = gr.Textbox(label="Optional stop token", placeholder="e.g., ### or <END>", lines=1)

    output = gr.Textbox(label="Model output", lines=14)

    def on_submit(prompt, max_new_tokens, temperature, top_p, num_beams, stop_token):
        return generate_answer(prompt, max_new_tokens, temperature, top_p, int(num_beams), stop_token)

    submit.click(
        on_submit,
        inputs=[prompt_input, max_tokens, temperature, top_p, num_beams, stop_token],
        outputs=output,
    )
    gr.Markdown("Disclaimer: This demo is for evaluation and research. It is not a medical device.")
if __name__ == "__main__":
    demo.launch()