mssaidat committed
Commit cdd5fed · verified · 1 Parent(s): 5f1ac8e

Create app.py

Files changed (1): app.py (+229, -0)
app.py ADDED
@@ -0,0 +1,229 @@
import os
import json
import gradio as gr
from typing import Any, Optional

from huggingface_hub import model_info  # note: the Hub helper is `model_info`, not `get_model_info`
from transformers import pipeline, Pipeline

# --- Config ---
# You gave a Space URL, so we'll assume your *model* lives at "mssaidat/Radiologist".
# If your actual model id is different, either:
#   1) change DEFAULT_MODEL_ID below, or
#   2) type the correct id in the UI and click "Load / Reload".
DEFAULT_MODEL_ID = os.getenv("MODEL_ID", "mssaidat/Radiologist")
HF_TOKEN = os.getenv("HF_TOKEN", None)

# --- Globals ---
pl: Optional[Pipeline] = None
current_task: Optional[str] = None
current_model_id: str = DEFAULT_MODEL_ID


# --- Helpers ---
def _pretty(obj: Any) -> str:
    try:
        return json.dumps(obj, indent=2, ensure_ascii=False)
    except Exception:
        return str(obj)
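
# Illustrative: _pretty({"label": "POSITIVE", "score": 0.99}) returns an
# indented JSON string, which is what the gr.Code output boxes below display.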

def detect_task(model_id: str) -> str:
    """
    Uses the model's Hub metadata to determine its pipeline task.
    """
    info = model_info(model_id, token=HF_TOKEN)
    # Preferred: pipeline_tag; fallback: tags
    if info.pipeline_tag:
        return info.pipeline_tag
    # Rare fallback if pipeline_tag is missing:
    tags = set(info.tags or [])
    # Crude heuristics
    if "text-generation" in tags or "causal-lm" in tags:
        return "text-generation"
    if "text2text-generation" in tags or "seq2seq" in tags:
        return "text2text-generation"
    if "fill-mask" in tags:
        return "fill-mask"
    if "token-classification" in tags:
        return "token-classification"
    if "text-classification" in tags or "sentiment-analysis" in tags:
        return "text-classification"
    if "question-answering" in tags:
        return "question-answering"
    if "image-classification" in tags:
        return "image-classification"
    if "automatic-speech-recognition" in tags or "asr" in tags:
        return "automatic-speech-recognition"
    # Last resort
    return "text-generation"
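
# Illustrative (assumes the model card sets pipeline_tag):
#   detect_task("mssaidat/Radiologist") might return e.g. "image-classification".
# If pipeline_tag is absent, the tag heuristics above decide, defaulting to
# "text-generation".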

SUPPORTED = {
    # text inputs
    "text-generation",
    "text2text-generation",
    "fill-mask",
    "token-classification",
    "text-classification",
    "question-answering",
    # image input
    "image-classification",
    # audio input
    "automatic-speech-recognition",
}
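# To support another task (e.g. summarization), add it here and route it in
# the infer_* functions below.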

def load_pipeline(model_id: str):
    global pl, current_task, current_model_id
    task = detect_task(model_id)
    if task not in SUPPORTED:
        raise ValueError(
            f"Detected task '{task}', which this demo doesn't handle yet. "
            f"Supported: {sorted(SUPPORTED)}"
        )
    # device_map="auto" uses the GPU if one is available in the Space
    pl = pipeline(task=task, model=model_id, token=HF_TOKEN, device_map="auto")
    current_task = task
    current_model_id = model_id
    return task
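
# Illustrative: load_pipeline("distilbert-base-uncased-finetuned-sst-2-english")
# would detect "text-classification" and build the matching pipeline.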


# --- Inference functions (simple, generic) ---
def infer_text(prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."

    if current_task in ["text-generation", "text2text-generation"]:
        out = pl(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )
        # Pipelines may return list[dict] with 'generated_text' or 'summary_text'
        if isinstance(out, list) and out and "generated_text" in out[0]:
            return out[0]["generated_text"]
        return _pretty(out)

    elif current_task == "fill-mask":
        out = pl(prompt)
        return _pretty(out)

    elif current_task == "text-classification":
        out = pl(prompt, top_k=None)  # full score distribution if supported
        return _pretty(out)

    elif current_task == "token-classification":  # NER
        out = pl(prompt, aggregation_strategy="simple")
        return _pretty(out)

    elif current_task == "question-answering":
        # Expect "prompt" like: "QUESTION <sep> CONTEXT"
        # Minimal UX: split on the first "<sep>" or line break
        if "<sep>" in prompt:
            q, c = prompt.split("<sep>", 1)
        elif "\n" in prompt:
            q, c = prompt.split("\n", 1)
        else:
            return ("For question-answering, provide input as:\n"
                    "question <sep> context\nor\nquestion\\ncontext")
        out = pl(question=q.strip(), context=c.strip())
        return _pretty(out)

    else:
        return f"Current task '{current_task}' uses a different tab."
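
# Example QA input (illustrative):
#   "What does the scan show? <sep> The chest X-ray shows mild cardiomegaly."
# splits into question="What does the scan show?" and
# context="The chest X-ray shows mild cardiomegaly."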


def infer_image(image) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."
    if current_task != "image-classification":
        return f"Loaded task '{current_task}'. Use the appropriate tab."
    out = pl(image)
    return _pretty(out)


def infer_audio(audio) -> str:
    if pl is None or current_task is None:
        return "Model not loaded. Click 'Load / Reload' first."
    if current_task != "automatic-speech-recognition":
        return f"Loaded task '{current_task}'. Use the appropriate tab."
    # gr.Audio returns (sample_rate, data) or a file path depending on `type`;
    # the UI below uses type="filepath", so the pipeline receives a path it can read
    out = pl(audio)
    return _pretty(out)
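
# Note: the ASR pipeline returns a dict like {"text": "..."}, which _pretty()
# renders as JSON in the output box.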


def do_load(model_id: str):
    try:
        task = load_pipeline(model_id.strip())
        msg = f"✅ Loaded '{model_id}' as task: {task}"
        hint = {
            "text-generation": "Use the **Text** tab. Enter a prompt; tweak max_new_tokens/temperature/top_p.",
            "text2text-generation": "Use the **Text** tab for instructions → outputs.",
            "fill-mask": "Use the **Text** tab. Include the [MASK] token in your input.",
            "text-classification": "Use the **Text** tab. Paste text to classify.",
            "token-classification": "Use the **Text** tab. Paste text for NER.",
            "question-answering": "Use the **Text** tab. Format: `question <sep> context` (or a line break).",
            "image-classification": "Use the **Image** tab and upload an image.",
            "automatic-speech-recognition": "Use the **Audio** tab and upload/record audio.",
        }.get(task, "")
        return msg + ("\n" + hint if hint else "")
    except Exception as e:
        return f"❌ Load failed: {e}"


# --- UI ---
with gr.Blocks(title="Radiologist — Hugging Face Space", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🩺 Radiologist — Universal Model Demo
        This Space auto-detects your model's task from the Hub and gives you the right input panel.

        **How to use**
        1. Enter your model id (e.g., `mssaidat/Radiologist`) and click **Load / Reload**.
        2. Use the matching tab (**Text**, **Image**, or **Audio**).
        """
    )

    with gr.Row():
        model_id_box = gr.Textbox(
            label="Model ID",
            value=DEFAULT_MODEL_ID,
            placeholder="e.g. mssaidat/Radiologist"
        )
        load_btn = gr.Button("Load / Reload", variant="primary")
    status = gr.Markdown("*(No model loaded yet)*")

    with gr.Tabs():
        with gr.Tab("Text"):
            text_in = gr.Textbox(
                label="Text Input",
                placeholder=(
                    "Enter a prompt.\n"
                    "For QA models: question <sep> context (or question on the first line, context on the second)"
                ),
                lines=6
            )
            with gr.Row():
                max_new_tokens = gr.Slider(1, 1024, value=256, step=1, label="max_new_tokens")
                # Min is 0.1 rather than 0.0: generate() rejects temperature=0 when do_sample=True
                temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.05, label="temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
            run_text = gr.Button("Run Text Inference")
            text_out = gr.Code(label="Output", language="json")

        with gr.Tab("Image"):
            img_in = gr.Image(label="Upload Image", type="pil")
            run_img = gr.Button("Run Image Inference")
            img_out = gr.Code(label="Output", language="json")

        with gr.Tab("Audio"):
            aud_in = gr.Audio(label="Upload/Record Audio", type="filepath")
            run_aud = gr.Button("Run ASR Inference")
            aud_out = gr.Code(label="Output", language="json")

    # Wire events
    load_btn.click(fn=do_load, inputs=model_id_box, outputs=status)
    run_text.click(fn=infer_text, inputs=[text_in, max_new_tokens, temperature, top_p], outputs=text_out)
    run_img.click(fn=infer_image, inputs=img_in, outputs=img_out)
    run_aud.click(fn=infer_audio, inputs=aud_in, outputs=aud_out)

demo.queue().launch()
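
For reference, task detection hinges on the model's `pipeline_tag` metadata on the Hub. A minimal standalone check (a sketch, assuming `huggingface_hub` is installed and the model id is correct):

from huggingface_hub import model_info

info = model_info("mssaidat/Radiologist")
print(info.pipeline_tag)  # the task detect_task() will pick, if set on the model card
print(info.tags)          # fallback tags used by the heuristics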