import os
import json
import gradio as gr
from typing import Any, Dict, List, Optional
from huggingface_hub import model_info
from transformers import pipeline, Pipeline
# --- Config ---
# You gave a Space URL, so we'll assume your *model* lives at "mssaidat/Radiologist".
# If your actual model id is different, either:
# 1) change DEFAULT_MODEL_ID below, or
# 2) type the correct id in the UI and click "Load / Reload".
DEFAULT_MODEL_ID = os.getenv("MODEL_ID", "mssaidat/Radiologist")
HF_TOKEN = os.getenv("HF_TOKEN", None)
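# Note: both values can be supplied via the Space's settings (variables/secrets);
# HF_TOKEN is only needed when the model repo is private or gated.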
# --- Globals ---
pl: Optional[Pipeline] = None
current_task: Optional[str] = None
current_model_id: str = DEFAULT_MODEL_ID
# --- Helpers ---
def _pretty(obj: Any) -> str:
    try:
        return json.dumps(obj, indent=2, ensure_ascii=False)
    except Exception:
        return str(obj)

def detect_task(model_id: str) -> str:
    """
    Uses the model's Hub metadata to determine its pipeline task.
    """
    info = model_info(model_id, token=HF_TOKEN)
    # Preferred: pipeline_tag; fallback: tags
    if info.pipeline_tag:
        return info.pipeline_tag
    # Rare fallback if pipeline_tag is missing: crude tag heuristics
    tags = set(info.tags or [])
    if "text-generation" in tags or "causal-lm" in tags:
        return "text-generation"
    if "text2text-generation" in tags or "seq2seq" in tags:
        return "text2text-generation"
    if "fill-mask" in tags:
        return "fill-mask"
    if "token-classification" in tags:
        return "token-classification"
    if "text-classification" in tags or "sentiment-analysis" in tags:
        return "text-classification"
    if "question-answering" in tags:
        return "question-answering"
    if "image-classification" in tags:
        return "image-classification"
    if "automatic-speech-recognition" in tags or "asr" in tags:
        return "automatic-speech-recognition"
    # Last resort
    return "text-generation"

SUPPORTED = {
    # text inputs
    "text-generation",
    "text2text-generation",
    "fill-mask",
    "token-classification",
    "text-classification",
    "question-answering",
    # image input
    "image-classification",
    # audio input
    "automatic-speech-recognition",
}

def load_pipeline(model_id: str):
    global pl, current_task, current_model_id
    task = detect_task(model_id)
    if task not in SUPPORTED:
        raise ValueError(
            f"Detected task '{task}', which this demo doesn't handle yet. "
            f"Supported: {sorted(SUPPORTED)}"
        )
    # device_map="auto" to use the GPU if one is available in the Space
    pl = pipeline(task=task, model=model_id, token=HF_TOKEN, device_map="auto")
    current_task = task
    current_model_id = model_id
    return task
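
# Illustrative usage (example model id, not the default for this Space):
#   task = load_pipeline("distilbert-base-uncased-finetuned-sst-2-english")
#   # -> "text-classification"; infer_text() will then route to that branch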
# --- Inference functions (simple, generic) ---
def infer_text(prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."
    if current_task in ["text-generation", "text2text-generation"]:
        out = pl(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )
        # Pipelines may return list[dict] with 'generated_text' or 'summary_text'
        if isinstance(out, list) and out and "generated_text" in out[0]:
            return out[0]["generated_text"]
        return _pretty(out)
    elif current_task == "fill-mask":
        out = pl(prompt)
        return _pretty(out)
    elif current_task == "text-classification":
        out = pl(prompt, top_k=None)  # full distribution if supported
        return _pretty(out)
    elif current_task == "token-classification":  # NER
        out = pl(prompt, aggregation_strategy="simple")
        return _pretty(out)
    elif current_task == "question-answering":
        # Expected "prompt" format: "QUESTION <sep> CONTEXT".
        # Minimal UX: split on the first "<sep>" or line break.
        if "<sep>" in prompt:
            q, c = prompt.split("<sep>", 1)
        elif "\n" in prompt:
            q, c = prompt.split("\n", 1)
        else:
            return ("For question-answering, provide input as:\n"
                    "question <sep> context\nor\nquestion\\ncontext")
        out = pl(question=q.strip(), context=c.strip())
        return _pretty(out)
    else:
        return f"Current task '{current_task}' uses a different tab."
def infer_image(image) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."
    if current_task != "image-classification":
        return f"Loaded task '{current_task}'. Use the appropriate tab."
    out = pl(image)
    return _pretty(out)

def infer_audio(audio) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."
    if current_task != "automatic-speech-recognition":
        return f"Loaded task '{current_task}'. Use the appropriate tab."
    # gr.Audio returns (sample_rate, data) or a file path depending on `type`
    out = pl(audio)
    return _pretty(out)

def do_load(model_id: str):
    try:
        task = load_pipeline(model_id.strip())
        msg = f"✅ Loaded '{model_id}' as task: {task}"
        hint = {
            "text-generation": "Use the **Text** tab. Enter a prompt; tweak max_new_tokens/temperature/top_p.",
            "text2text-generation": "Use the **Text** tab for instructions → outputs.",
            "fill-mask": "Use the **Text** tab. Include the [MASK] token in your input.",
            "text-classification": "Use the **Text** tab. Paste text to classify.",
            "token-classification": "Use the **Text** tab. Paste text for NER.",
            "question-answering": "Use the **Text** tab. Format: `question <sep> context` (or line break).",
            "image-classification": "Use the **Image** tab and upload an image.",
            "automatic-speech-recognition": "Use the **Audio** tab and upload/record audio.",
        }.get(task, "")
        return msg + ("\n" + hint if hint else "")
    except Exception as e:
        return f"❌ Load failed: {e}"

# --- UI ---
with gr.Blocks(title="Radiologist – Hugging Face Space", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🩺 Radiologist – Universal Model Demo
        This Space auto-detects your model's task from the Hub and gives you the right input panel.

        **How to use**
        1. Enter your model id (e.g., `mssaidat/Radiologist`) and click **Load / Reload**.
        2. Use the matching tab (**Text**, **Image**, or **Audio**).
        """
    )
    with gr.Row():
        model_id_box = gr.Textbox(
            label="Model ID",
            value=DEFAULT_MODEL_ID,
            placeholder="e.g. mssaidat/Radiologist",
        )
        load_btn = gr.Button("Load / Reload", variant="primary")
    status = gr.Markdown("*(No model loaded yet)*")
    with gr.Tabs():
        with gr.Tab("Text"):
            text_in = gr.Textbox(
                label="Text Input",
                placeholder=(
                    "Enter a prompt.\n"
                    "For QA models: question <sep> context (or question on first line, context on second)"
                ),
                lines=6,
            )
            with gr.Row():
                max_new_tokens = gr.Slider(1, 1024, value=256, step=1, label="max_new_tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
            run_text = gr.Button("Run Text Inference")
            text_out = gr.Code(label="Output", language="json")
        with gr.Tab("Image"):
            img_in = gr.Image(label="Upload Image", type="pil")
            run_img = gr.Button("Run Image Inference")
            img_out = gr.Code(label="Output", language="json")
        with gr.Tab("Audio"):
            aud_in = gr.Audio(label="Upload/Record Audio", type="filepath")
            run_aud = gr.Button("Run ASR Inference")
            aud_out = gr.Code(label="Output", language="json")

    # Wire events
    load_btn.click(fn=do_load, inputs=model_id_box, outputs=status)
    run_text.click(fn=infer_text, inputs=[text_in, max_new_tokens, temperature, top_p], outputs=text_out)
    run_img.click(fn=infer_image, inputs=img_in, outputs=img_out)
    run_aud.click(fn=infer_audio, inputs=aud_in, outputs=aud_out)

demo.queue().launch()