import os
import json
from typing import Any, Optional

import gradio as gr
from huggingface_hub import model_info
from transformers import pipeline, Pipeline

# --- Config ---
# You gave a Space URL, so we'll assume your *model* lives at "mssaidat/Radiologist".
# If your actual model id is different, either:
#   1) change DEFAULT_MODEL_ID below, or
#   2) type the correct id in the UI and click "Load / Reload".
DEFAULT_MODEL_ID = os.getenv("MODEL_ID", "mssaidat/Radiologist")
HF_TOKEN = os.getenv("HF_TOKEN", None)

# --- Globals ---
pl: Optional[Pipeline] = None
current_task: Optional[str] = None
current_model_id: str = DEFAULT_MODEL_ID


# --- Helpers ---
def _pretty(obj: Any) -> str:
    try:
        return json.dumps(obj, indent=2, ensure_ascii=False)
    except Exception:
        return str(obj)


def detect_task(model_id: str) -> str:
    """Uses the model's Hub metadata to determine its pipeline task."""
    info = model_info(model_id, token=HF_TOKEN)
    # Preferred: pipeline_tag; fallback: tags.
    if info.pipeline_tag:
        return info.pipeline_tag
    # Rare fallback if pipeline_tag is missing: crude tag heuristics.
    tags = set(info.tags or [])
    if "text-generation" in tags or "causal-lm" in tags:
        return "text-generation"
    if "text2text-generation" in tags or "seq2seq" in tags:
        return "text2text-generation"
    if "fill-mask" in tags:
        return "fill-mask"
    if "token-classification" in tags:
        return "token-classification"
    if "text-classification" in tags or "sentiment-analysis" in tags:
        return "text-classification"
    if "question-answering" in tags:
        return "question-answering"
    if "image-classification" in tags:
        return "image-classification"
    if "automatic-speech-recognition" in tags or "asr" in tags:
        return "automatic-speech-recognition"
    # Last resort
    return "text-generation"


SUPPORTED = {
    # text inputs
    "text-generation",
    "text2text-generation",
    "fill-mask",
    "token-classification",
    "text-classification",
    "question-answering",
    # image input
    "image-classification",
    # audio input
    "automatic-speech-recognition",
}


def load_pipeline(model_id: str) -> str:
    global pl, current_task, current_model_id
    task = detect_task(model_id)
    if task not in SUPPORTED:
        raise ValueError(
            f"Detected task '{task}', which this demo doesn't handle yet. "
            f"Supported: {sorted(SUPPORTED)}"
        )
    # device_map="auto" uses the GPU if one is available in the Space (requires accelerate).
    pl = pipeline(task=task, model=model_id, token=HF_TOKEN, device_map="auto")
    current_task = task
    current_model_id = model_id
    return task
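# Quick REPL sanity check for the two helpers above (illustrative only, not used by
# the app): the model id below is a public Hub model chosen as an example; substitute
# your own. The printed task depends on what the Hub reports and needs network access.
#
#   >>> detect_task("distilbert-base-uncased-finetuned-sst-2-english")
#   'text-classification'
#   >>> load_pipeline("distilbert-base-uncased-finetuned-sst-2-english")  # downloads weights
#   'text-classification'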
# --- Inference functions (simple, generic) ---
def infer_text(prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."

    if current_task in ["text-generation", "text2text-generation"]:
        out = pl(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )
        # pipelines may return list[dict] with 'generated_text' or 'summary_text'
        if isinstance(out, list) and out and "generated_text" in out[0]:
            return out[0]["generated_text"]
        return _pretty(out)
    elif current_task == "fill-mask":
        out = pl(prompt)
        return _pretty(out)
    elif current_task == "text-classification":
        out = pl(prompt, top_k=None)  # full distribution if supported
        return _pretty(out)
    elif current_task == "token-classification":
        # NER
        out = pl(prompt, aggregation_strategy="simple")
        return _pretty(out)
    elif current_task == "question-answering":
        # Minimal UX: expect the question on the first line, context on the lines below.
        if "\n" in prompt:
            q, c = prompt.split("\n", 1)
        else:
            return ("For question-answering, provide input as:\n"
                    "question on the first line, context on the lines below.")
        out = pl(question=q.strip(), context=c.strip())
        return _pretty(out)
    else:
        return f"Current task '{current_task}' uses a different tab."


def infer_image(image) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."
    if current_task != "image-classification":
        return f"Loaded task '{current_task}'. Use the appropriate tab."
    out = pl(image)
    return _pretty(out)


def infer_audio(audio) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."
    if current_task != "automatic-speech-recognition":
        return f"Loaded task '{current_task}'. Use the appropriate tab."
    # gr.Audio returns (sample_rate, data) or a file path depending on `type`
    out = pl(audio)
    return _pretty(out)


def do_load(model_id: str) -> str:
    try:
        task = load_pipeline(model_id.strip())
        msg = f"✅ Loaded '{model_id}' as task: {task}"
        hint = {
            "text-generation": "Use the **Text** tab. Enter a prompt; tweak max_new_tokens/temperature/top_p.",
            "text2text-generation": "Use the **Text** tab for instructions → outputs.",
            "fill-mask": "Use the **Text** tab. Include the model's mask token (e.g. [MASK]) in your input.",
            "text-classification": "Use the **Text** tab. Paste text to classify.",
            "token-classification": "Use the **Text** tab. Paste text for NER.",
            "question-answering": "Use the **Text** tab. Put the question on the first line and the context below it.",
            "image-classification": "Use the **Image** tab and upload an image.",
            "automatic-speech-recognition": "Use the **Audio** tab and upload/record audio.",
        }.get(task, "")
        return msg + ("\n" + hint if hint else "")
    except Exception as e:
        return f"❌ Load failed: {e}"


# --- UI ---
with gr.Blocks(title="Radiologist — Hugging Face Space", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🩺 Radiologist — Universal Model Demo
This Space auto-detects your model's task from the Hub and gives you the right input panel.

**How to use**
1. Enter your model id (e.g., `mssaidat/Radiologist`) and click **Load / Reload**.
2. Use the matching tab (**Text**, **Image**, or **Audio**).
"""
    )

    with gr.Row():
        model_id_box = gr.Textbox(
            label="Model ID",
            value=DEFAULT_MODEL_ID,
            placeholder="e.g. mssaidat/Radiologist",
        )
        load_btn = gr.Button("Load / Reload", variant="primary")

    status = gr.Markdown("*(No model loaded yet)*")

    with gr.Tabs():
        with gr.Tab("Text"):
            text_in = gr.Textbox(
                label="Text Input",
                placeholder=(
                    "Enter a prompt.\n"
                    "For QA models: question on the first line, context on the lines below."
                ),
                lines=6,
            )
            with gr.Row():
                max_new_tokens = gr.Slider(1, 1024, value=256, step=1, label="max_new_tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
            run_text = gr.Button("Run Text Inference")
            text_out = gr.Code(label="Output", language="json")

        with gr.Tab("Image"):
            img_in = gr.Image(label="Upload Image", type="pil")
            run_img = gr.Button("Run Image Inference")
            img_out = gr.Code(label="Output", language="json")

        with gr.Tab("Audio"):
            aud_in = gr.Audio(label="Upload/Record Audio", type="filepath")
            run_aud = gr.Button("Run ASR Inference")
            aud_out = gr.Code(label="Output", language="json")

    # Wire events
    load_btn.click(fn=do_load, inputs=model_id_box, outputs=status)
    run_text.click(
        fn=infer_text,
        inputs=[text_in, max_new_tokens, temperature, top_p],
        outputs=text_out,
    )
    run_img.click(fn=infer_image, inputs=img_in, outputs=img_out)
    run_aud.click(fn=infer_audio, inputs=aud_in, outputs=aud_out)

demo.queue().launch()