import os
import json
import gradio as gr
from typing import Any, Optional
from huggingface_hub import model_info
from transformers import pipeline, Pipeline

# --- Config ---
# You gave a Space URL, so we'll assume your *model* lives at "mssaidat/Radiologist".
# If your actual model id is different, either:
#   1) change DEFAULT_MODEL_ID below, or
#   2) type the correct id in the UI and click "Load / Reload".
DEFAULT_MODEL_ID = os.getenv("MODEL_ID", "mssaidat/Radiologist")
HF_TOKEN = os.getenv("HF_TOKEN", None)
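# Note: for private or gated models, HF_TOKEN must be set. On a Space this is
# typically done via Settings -> "Variables and secrets" (assumption: the token
# has read access to the model repo).
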
# --- Globals ---
pl: Optional[Pipeline] = None
current_task: Optional[str] = None
current_model_id: str = DEFAULT_MODEL_ID

# --- Helpers ---
def _pretty(obj: Any) -> str:
    try:
        return json.dumps(obj, indent=2, ensure_ascii=False)
    except Exception:
        return str(obj)
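
# Example: _pretty({"label": "POSITIVE", "score": 0.98}) returns an indented
# JSON string; objects that json.dumps can't serialize fall back to str().
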
def detect_task(model_id: str) -> str:
    """
    Uses the model's Hub metadata to determine its pipeline task.
    """
    info = model_info(model_id, token=HF_TOKEN)
    # Preferred: pipeline_tag; fallback: tags
    if info.pipeline_tag:
        return info.pipeline_tag
    # Rare fallback if pipeline_tag is missing: crude tag heuristics
    tags = set(info.tags or [])
    if "text-generation" in tags or "causal-lm" in tags:
        return "text-generation"
    if "text2text-generation" in tags or "seq2seq" in tags:
        return "text2text-generation"
    if "fill-mask" in tags:
        return "fill-mask"
    if "token-classification" in tags:
        return "token-classification"
    if "text-classification" in tags or "sentiment-analysis" in tags:
        return "text-classification"
    if "question-answering" in tags:
        return "question-answering"
    if "image-classification" in tags:
        return "image-classification"
    if "automatic-speech-recognition" in tags or "asr" in tags:
        return "automatic-speech-recognition"
    # Last resort
    return "text-generation"
SUPPORTED = {
    # text inputs
    "text-generation",
    "text2text-generation",
    "fill-mask",
    "token-classification",
    "text-classification",
    "question-answering",
    # image input
    "image-classification",
    # audio input
    "automatic-speech-recognition",
}

def load_pipeline(model_id: str):
    global pl, current_task, current_model_id
    task = detect_task(model_id)
    if task not in SUPPORTED:
        raise ValueError(
            f"Detected task '{task}', which this demo doesn't handle yet. "
            f"Supported: {sorted(SUPPORTED)}"
        )
    # device_map="auto" uses the GPU if one is available in the Space
    # (requires the `accelerate` package to be installed).
    pl = pipeline(task=task, model=model_id, token=HF_TOKEN, device_map="auto")
    current_task = task
    current_model_id = model_id
    return task
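
# Sketch of direct use (hypothetical model id; output depends on the repo):
#   task = load_pipeline("mssaidat/Radiologist")
#   print(task)  # e.g. "image-classification"
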
# --- Inference functions (simple, generic) ---
def infer_text(prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."
    if current_task in ["text-generation", "text2text-generation"]:
        out = pl(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )
        # Pipelines may return list[dict] with 'generated_text' or 'summary_text'
        if isinstance(out, list) and out and "generated_text" in out[0]:
            return out[0]["generated_text"]
        return _pretty(out)
    elif current_task == "fill-mask":
        out = pl(prompt)
        return _pretty(out)
    elif current_task == "text-classification":
        out = pl(prompt, top_k=None)  # full label distribution if supported
        return _pretty(out)
    elif current_task == "token-classification":  # NER
        out = pl(prompt, aggregation_strategy="simple")
        return _pretty(out)
    elif current_task == "question-answering":
        # Expect "prompt" like: "QUESTION <sep> CONTEXT".
        # Minimal UX: split on the first "<sep>" or line break.
        if "<sep>" in prompt:
            q, c = prompt.split("<sep>", 1)
        elif "\n" in prompt:
            q, c = prompt.split("\n", 1)
        else:
            return ("For question-answering, provide input as:\n"
                    "question <sep> context\nor\nquestion\\ncontext")
        out = pl(question=q.strip(), context=c.strip())
        return _pretty(out)
    else:
        return f"Current task '{current_task}' uses a different tab."

def infer_image(image) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."
    if current_task != "image-classification":
        return f"Loaded task '{current_task}'. Use the appropriate tab."
    out = pl(image)
    return _pretty(out)

def infer_audio(audio) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."
    if current_task != "automatic-speech-recognition":
        return f"Loaded task '{current_task}'. Use the appropriate tab."
    # gr.Audio is configured with type="filepath" below, so `audio` is a file
    # path, which the ASR pipeline accepts directly.
    out = pl(audio)
    return _pretty(out)
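
# Note: for long recordings, the pipeline can chunk the input, e.g.
# pl(audio, chunk_length_s=30) (assumption: the loaded model supports
# long-form chunked inference, as Whisper-style checkpoints do).
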
def do_load(model_id: str):
    try:
        task = load_pipeline(model_id.strip())
        msg = f"✅ Loaded '{model_id}' as task: {task}"
        hint = {
            "text-generation": "Use the **Text** tab. Enter a prompt; tweak max_new_tokens/temperature/top_p.",
            "text2text-generation": "Use the **Text** tab for instructions → outputs.",
            "fill-mask": "Use the **Text** tab. Include the model's mask token (e.g. [MASK]) in your input.",
            "text-classification": "Use the **Text** tab. Paste text to classify.",
            "token-classification": "Use the **Text** tab. Paste text for NER.",
            "question-answering": "Use the **Text** tab. Format: `question <sep> context` (or line break).",
            "image-classification": "Use the **Image** tab and upload an image.",
            "automatic-speech-recognition": "Use the **Audio** tab and upload/record audio.",
        }.get(task, "")
        return msg + ("\n" + hint if hint else "")
    except Exception as e:
        return f"❌ Load failed: {e}"

# --- UI ---
with gr.Blocks(title="Radiologist – Hugging Face Space", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🩺 Radiologist – Universal Model Demo
        This Space auto-detects your model's task from the Hub and gives you the right input panel.

        **How to use**
        1. Enter your model id (e.g., `mssaidat/Radiologist`) and click **Load / Reload**.
        2. Use the matching tab (**Text**, **Image**, or **Audio**).
        """
    )
    with gr.Row():
        model_id_box = gr.Textbox(
            label="Model ID",
            value=DEFAULT_MODEL_ID,
            placeholder="e.g. mssaidat/Radiologist",
        )
        load_btn = gr.Button("Load / Reload", variant="primary")

    status = gr.Markdown("*(No model loaded yet)*")
    with gr.Tabs():
        with gr.Tab("Text"):
            text_in = gr.Textbox(
                label="Text Input",
                placeholder=(
                    "Enter a prompt.\n"
                    "For QA models: question <sep> context (or question on first line, context on second)"
                ),
                lines=6,
            )
            with gr.Row():
                max_new_tokens = gr.Slider(1, 1024, value=256, step=1, label="max_new_tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
            run_text = gr.Button("Run Text Inference")
            text_out = gr.Code(label="Output", language="json")

        with gr.Tab("Image"):
            img_in = gr.Image(label="Upload Image", type="pil")
            run_img = gr.Button("Run Image Inference")
            img_out = gr.Code(label="Output", language="json")

        with gr.Tab("Audio"):
            aud_in = gr.Audio(label="Upload/Record Audio", type="filepath")
            run_aud = gr.Button("Run ASR Inference")
            aud_out = gr.Code(label="Output", language="json")
    # Wire events
    load_btn.click(fn=do_load, inputs=model_id_box, outputs=status)
    run_text.click(fn=infer_text, inputs=[text_in, max_new_tokens, temperature, top_p], outputs=text_out)
    run_img.click(fn=infer_image, inputs=img_in, outputs=img_out)
    run_aud.click(fn=infer_audio, inputs=aud_in, outputs=aud_out)

demo.queue().launch()
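
# To run locally (assumption about the exact dependency set):
#   pip install gradio transformers huggingface_hub accelerate torch
#   python app.py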