# Radiologist / app.py
import os
import json
import gradio as gr
from typing import Any, Optional
from huggingface_hub import get_model_info
from transformers import pipeline, Pipeline
# --- Config ---
# You gave a Space URL, so we'll assume your *model* lives at "mssaidat/Radiologist".
# If your actual model id is different, either:
# 1) change DEFAULT_MODEL_ID below, or
# 2) type the correct id in the UI and click "Load / Reload".
DEFAULT_MODEL_ID = os.getenv("MODEL_ID", "mssaidat/Radiologist")
HF_TOKEN = os.getenv("HF_TOKEN", None)
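# Both values are read from the environment, so they can be set as Space
# secrets/variables or exported before a local run. A rough local-run sketch
# (the token value is a placeholder, not a real secret):
#
#   export MODEL_ID="mssaidat/Radiologist"
#   export HF_TOKEN="hf_xxx"   # only needed for private/gated models
#   python app.py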
# --- Globals ---
pl: Optional[Pipeline] = None
current_task: Optional[str] = None
current_model_id: str = DEFAULT_MODEL_ID
# --- Helpers ---
def _pretty(obj: Any) -> str:
    try:
        return json.dumps(obj, indent=2, ensure_ascii=False)
    except Exception:
        return str(obj)
def detect_task(model_id: str) -> str:
"""
Uses the model's Hub config to determine its pipeline task.
"""
info = get_model_info(model_id, token=HF_TOKEN)
# Preferred: pipeline_tag; Fallback: tags
if info.pipeline_tag:
return info.pipeline_tag
# Rare fallback if pipeline_tag missing:
tags = set(info.tags or [])
# crude heuristics
if "text-generation" in tags or "causal-lm" in tags:
return "text-generation"
if "text2text-generation" in tags or "seq2seq" in tags:
return "text2text-generation"
if "fill-mask" in tags:
return "fill-mask"
if "token-classification" in tags:
return "token-classification"
if "text-classification" in tags or "sentiment-analysis" in tags:
return "text-classification"
if "question-answering" in tags:
return "question-answering"
if "image-classification" in tags:
return "image-classification"
if "automatic-speech-recognition" in tags or "asr" in tags:
return "automatic-speech-recognition"
# Last resort
return "text-generation"
SUPPORTED = {
    # text inputs
    "text-generation",
    "text2text-generation",
    "fill-mask",
    "token-classification",
    "text-classification",
    "question-answering",
    # image input
    "image-classification",
    # audio input
    "automatic-speech-recognition",
}
def load_pipeline(model_id: str):
    global pl, current_task, current_model_id
    task = detect_task(model_id)
    if task not in SUPPORTED:
        raise ValueError(
            f"Detected task '{task}', which this demo doesn't handle yet. "
            f"Supported: {sorted(list(SUPPORTED))}"
        )
    # device_map="auto" to use GPU if available in the Space
    pl = pipeline(task=task, model=model_id, token=HF_TOKEN, device_map="auto")
    current_task = task
    current_model_id = model_id
    return task
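# Rough usage sketch outside the UI (the returned task string depends entirely
# on the model's Hub metadata; "image-classification" is only an example):
#
#   task = load_pipeline("mssaidat/Radiologist")
#   # e.g. task == "image-classification" -> pl("chest_xray.png") yields label/score pairs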
# --- Inference functions (simple, generic) ---
def infer_text(prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."
    if current_task in ["text-generation", "text2text-generation"]:
        out = pl(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True
        )
        # pipelines may return list[dict] with 'generated_text' or 'summary_text'
        if isinstance(out, list) and out and "generated_text" in out[0]:
            return out[0]["generated_text"]
        return _pretty(out)
    elif current_task == "fill-mask":
        out = pl(prompt)
        return _pretty(out)
    elif current_task == "text-classification":
        out = pl(prompt, top_k=None)  # full distribution if supported
        return _pretty(out)
    elif current_task == "token-classification":  # NER
        out = pl(prompt, aggregation_strategy="simple")
        return _pretty(out)
    elif current_task == "question-answering":
        # Expect "prompt" like: "QUESTION <sep> CONTEXT"
        # Minimal UX: split on first line break or <sep>
        if "<sep>" in prompt:
            q, c = prompt.split("<sep>", 1)
        elif "\n" in prompt:
            q, c = prompt.split("\n", 1)
        else:
            return ("For question-answering, provide input as:\n"
                    "question <sep> context\nor\nquestion\\ncontext")
        out = pl(question=q.strip(), context=c.strip())
        return _pretty(out)
    else:
        return f"Current task '{current_task}' uses a different tab."
def infer_image(image) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."
    if current_task != "image-classification":
        return f"Loaded task '{current_task}'. Use the appropriate tab."
    out = pl(image)
    return _pretty(out)
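# For reference, image-classification pipelines typically return a list of
# label/score dicts, which _pretty renders as JSON; the labels here are invented:
#
#   [{"label": "PNEUMONIA", "score": 0.97}, {"label": "NORMAL", "score": 0.03}]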
def infer_audio(audio) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."
    if current_task != "automatic-speech-recognition":
        return f"Loaded task '{current_task}'. Use the appropriate tab."
    # gr.Audio returns (sample_rate, data) or a file path depending on type;
    # the Audio component below uses type="filepath", so the pipeline gets a path
    out = pl(audio)
    return _pretty(out)
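# For reference, ASR pipelines usually return a dict with a "text" key,
# e.g. {"text": "no acute cardiopulmonary findings"} (the transcript is invented).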
def do_load(model_id: str):
    try:
        task = load_pipeline(model_id.strip())
        msg = f"✅ Loaded '{model_id}' as task: {task}"
        hint = {
            "text-generation": "Use the **Text** tab. Enter a prompt; tweak max_new_tokens/temperature/top_p.",
            "text2text-generation": "Use the **Text** tab for instructions → outputs.",
            "fill-mask": "Use the **Text** tab. Include the [MASK] token in your input.",
            "text-classification": "Use the **Text** tab. Paste text to classify.",
            "token-classification": "Use the **Text** tab. Paste text for NER.",
            "question-answering": "Use the **Text** tab. Format: `question <sep> context` (or line break).",
            "image-classification": "Use the **Image** tab and upload an image.",
            "automatic-speech-recognition": "Use the **Audio** tab and upload/record audio."
        }.get(task, "")
        return msg + ("\n" + hint if hint else "")
    except Exception as e:
        return f"❌ Load failed: {e}"
# --- UI ---
with gr.Blocks(title="Radiologist — Hugging Face Space", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🩺 Radiologist — Universal Model Demo
This Space auto-detects your model's task from the Hub and gives you the right input panel.

**How to use**
1. Enter your model id (e.g., `mssaidat/Radiologist`) and click **Load / Reload**.
2. Use the matching tab (**Text**, **Image**, or **Audio**).
"""
    )
    with gr.Row():
        model_id_box = gr.Textbox(
            label="Model ID",
            value=DEFAULT_MODEL_ID,
            placeholder="e.g. mssaidat/Radiologist"
        )
        load_btn = gr.Button("Load / Reload", variant="primary")
    status = gr.Markdown("*(No model loaded yet)*")
    with gr.Tabs():
        with gr.Tab("Text"):
            text_in = gr.Textbox(
                label="Text Input",
                placeholder=(
                    "Enter a prompt.\n"
                    "For QA models: question <sep> context (or question on first line, context on second)"
                ),
                lines=6
            )
            with gr.Row():
                max_new_tokens = gr.Slider(1, 1024, value=256, step=1, label="max_new_tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
            run_text = gr.Button("Run Text Inference")
            text_out = gr.Code(label="Output", language="json")
        with gr.Tab("Image"):
            img_in = gr.Image(label="Upload Image", type="pil")
            run_img = gr.Button("Run Image Inference")
            img_out = gr.Code(label="Output", language="json")
        with gr.Tab("Audio"):
            aud_in = gr.Audio(label="Upload/Record Audio", type="filepath")
            run_aud = gr.Button("Run ASR Inference")
            aud_out = gr.Code(label="Output", language="json")

    # Wire events
    load_btn.click(fn=do_load, inputs=model_id_box, outputs=status)
    run_text.click(fn=infer_text, inputs=[text_in, max_new_tokens, temperature, top_p], outputs=text_out)
    run_img.click(fn=infer_image, inputs=img_in, outputs=img_out)
    run_aud.click(fn=infer_audio, inputs=aud_in, outputs=aud_out)

demo.queue().launch()
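# Rough dependency sketch for running this Space locally (versions are not
# pinned here; adjust to your environment):
#
#   pip install gradio transformers huggingface_hub accelerate torch
#   python app.py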