marvinisjarvis committed · Commit a1b67ef · verified · 1 Parent(s): c59708b

Update app.py


This version replaces the deprecated use_auth_token argument with token and updates the InferenceClient.text_generation call to pass generation parameters as keyword arguments.

Files changed (1):
  1. app.py +216 -223
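
For context, a minimal sketch (not taken verbatim from the diff) of the argument rename this commit applies: recent transformers releases deprecate use_auth_token in favor of token in the from_pretrained methods, and the same keyword is forwarded to the Hub when the repo is private. The model id and HF_TOKEN usage below mirror the app.

```python
# Minimal sketch of the authentication-argument rename, assuming transformers is
# installed and HF_TOKEN is set; the deprecated form is shown only in the comment.
import os
from transformers import AutoTokenizer

HF_TOKEN = os.environ.get("HF_TOKEN")

# Deprecated: AutoTokenizer.from_pretrained("marvinisjarvis/radio_model", use_auth_token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(
    "marvinisjarvis/radio_model",
    token=HF_TOKEN,  # current keyword accepted by transformers and huggingface_hub
)
```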
app.py CHANGED
@@ -1,223 +1,216 @@
- # app.py
- """
- Gradio app that:
- - Uses a local model if torch is installed,
- - Otherwise tries Hugging Face InferenceClient,
- - Otherwise falls back to legacy InferenceApi with task="text-generation".
- Make sure HF_TOKEN is set in Space secrets if your model is private.
- """
-
- import os
- from typing import Optional
- import gradio as gr
-
- MODEL_ID = "marvinisjarvis/radio_model"
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
-
- # Flags & clients
- LOCAL_AVAILABLE = False
- INFERENCE_CLIENT_AVAILABLE = False
- INFERENCE_API_AVAILABLE = False
-
- # Attempt local loading (torch + transformers)
- try:
-     import torch
-     from transformers import AutoTokenizer, AutoModelForCausalLM
-
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     def try_load_local():
-         print("Attempting to load local model...")
-         tokenizer = AutoTokenizer.from_pretrained(
-             MODEL_ID, trust_remote_code=True, use_fast=True, use_auth_token=HF_TOKEN
-         )
-         kwargs = {"trust_remote_code": True, "use_auth_token": HF_TOKEN, "low_cpu_mem_usage": True}
-         if device == "cuda":
-             kwargs.update({"device_map": "auto", "torch_dtype": torch.float16})
-         model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
-         return tokenizer, model
-
-     try:
-         tokenizer, model = try_load_local()
-         LOCAL_AVAILABLE = True
-         print("Local model loaded.")
-     except Exception as e:
-         print("Local model load failed:", e)
-         LOCAL_AVAILABLE = False
-
- except Exception as e:
-     print("Torch not available or failed to import:", e)
-     LOCAL_AVAILABLE = False
-
-
- # Attempt to use InferenceClient (preferred)
- if not LOCAL_AVAILABLE:
-     try:
-         from huggingface_hub import InferenceClient
-         client = InferenceClient(token=HF_TOKEN)
-         INFERENCE_CLIENT_AVAILABLE = True
-         print("InferenceClient available - will use remote text-generation via InferenceClient.")
-     except Exception as e:
-         print("InferenceClient not available:", e)
-         INFERENCE_CLIENT_AVAILABLE = False
-
- # Fallback to legacy InferenceApi with explicit task
- inference_api = None
- if (not LOCAL_AVAILABLE) and (not INFERENCE_CLIENT_AVAILABLE):
-     try:
-         from huggingface_hub import InferenceApi
-         # Explicitly specify task to avoid "Task not specified" errors
-         inference_api = InferenceApi(repo_id=MODEL_ID, token=HF_TOKEN, task="text-generation")
-         INFERENCE_API_AVAILABLE = True
-         print("Using legacy InferenceApi with task='text-generation'.")
-     except Exception as e:
-         print("Hugging Face InferenceApi not available or failed:", e)
-         INFERENCE_API_AVAILABLE = False
-
-
- # Generation wrapper that handles all three paths
- def generate_answer(
-     prompt: str,
-     max_new_tokens: int = 256,
-     temperature: float = 0.7,
-     top_p: float = 0.9,
-     num_beams: int = 1,
-     stop_token: Optional[str] = None,
- ) -> str:
-     if not prompt or prompt.strip() == "":
-         return "Please enter a prompt."
-
-     # Local path
-     if LOCAL_AVAILABLE:
-         try:
-             inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
-             device0 = next(model.parameters()).device
-             input_ids = inputs["input_ids"].to(device0)
-             attention_mask = inputs.get("attention_mask")
-             if attention_mask is not None:
-                 attention_mask = attention_mask.to(device0)
-
-             gen_kwargs = dict(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 max_new_tokens=int(max_new_tokens),
-                 temperature=float(temperature),
-                 top_p=float(top_p),
-                 num_beams=int(num_beams),
-                 eos_token_id=getattr(tokenizer, "eos_token_id", None),
-                 pad_token_id=getattr(tokenizer, "pad_token_id", None),
-                 do_sample=(float(temperature) > 0) and (int(num_beams) == 1),
-             )
-             outputs = model.generate(**gen_kwargs)
-             decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
-             result = decoded[len(prompt) :].strip() if decoded.startswith(prompt) else decoded.strip()
-             if stop_token:
-                 idx = result.find(stop_token)
-                 if idx != -1:
-                     result = result[:idx].strip()
-             return result
-         except Exception as e:
-             print("Local generation failed, falling back to remote. Error:", e)
-
-     # InferenceClient path (preferred remote)
-     if INFERENCE_CLIENT_AVAILABLE:
-         try:
-             # The InferenceClient has convenience methods; we call text_generation
-             response = client.text_generation(
-                 model=MODEL_ID,
-                 inputs=prompt,
-                 parameters={
-                     "max_new_tokens": int(max_new_tokens),
-                     "temperature": float(temperature),
-                     "top_p": float(top_p),
-                     "num_beams": int(num_beams),
-                 },
-             )
-             # response may be a list or dict depending on client version
-             if isinstance(response, dict) and "generated_text" in response:
-                 out = response["generated_text"]
-             elif isinstance(response, list) and response and isinstance(response[0], dict) and "generated_text" in response[0]:
-                 out = response[0]["generated_text"]
-             else:
-                 out = str(response)
-             if out.startswith(prompt):
-                 out = out[len(prompt) :].strip()
-             if stop_token:
-                 idx = out.find(stop_token)
-                 if idx != -1:
-                     out = out[:idx].strip()
-             return out.strip()
-         except Exception as e:
-             print("InferenceClient call failed, will try legacy InferenceApi. Error:", e)
-
-     # Legacy InferenceApi fallback (explicit task)
-     if INFERENCE_API_AVAILABLE and inference_api is not None:
-         try:
-             params = {"max_new_tokens": int(max_new_tokens), "temperature": float(temperature), "top_p": float(top_p)}
-             res = inference_api(prompt, params=params)
-             # normalize response
-             if isinstance(res, str):
-                 out = res
-             elif isinstance(res, dict) and "generated_text" in res:
-                 out = res["generated_text"]
-             elif isinstance(res, list) and res and isinstance(res[0], dict) and "generated_text" in res[0]:
-                 out = res[0]["generated_text"]
-             else:
-                 out = str(res)
-             if out.startswith(prompt):
-                 out = out[len(prompt) :].strip()
-             if stop_token:
-                 idx = out.find(stop_token)
-                 if idx != -1:
-                     out = out[:idx].strip()
-             return out.strip()
-         except Exception as e:
-             print("Legacy InferenceApi failed:", e)
-             return f"Remote inference failed: {e}"
-
-     return ("No inference path available. Install torch for local inference or ensure HF_TOKEN is set and huggingface_hub supports InferenceClient/InferenceApi.")
-
-
- # --- Gradio UI ---
- title = "RadioModel — Radiology Q&A (Mistral 7B fine-tuned)"
- description = """
- Demo for marvinisjarvis/radio_model.
- Tries local inference first; otherwise uses Hugging Face remote inference.
- If your model is private, add HF_TOKEN in Space secrets. Not for clinical use.
- """
-
- with gr.Blocks(title=title) as demo:
-     gr.Markdown(f"## {title}")
-     gr.Markdown(description)
-
-     with gr.Row():
-         with gr.Column(scale=3):
-             prompt_input = gr.Textbox(label="Enter your radiology question", lines=6)
-             submit = gr.Button("Generate Answer")
-             examples = gr.Examples(
-                 examples=[
-                     "What does an X-ray of pneumonia typically show?",
-                     "How can you differentiate a benign lung nodule from a malignant one on CT?",
-                     "What are common signs of bone fracture on X-rays?",
-                     "Which imaging modality is best for detecting small brain tumors?"
-                 ],
-                 inputs=prompt_input
-             )
-         with gr.Column(scale=2):
-             max_tokens = gr.Slider(32, 1024, value=256, step=32, label="Max New Tokens")
-             temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
-             top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
-             num_beams = gr.Slider(1, 5, value=1, step=1, label="Num Beams")
-             stop_token = gr.Textbox(label="Optional stop token", placeholder="e.g., ### or <END>", lines=1)
-
-     output = gr.Textbox(label="Model output", lines=14)
-
-     def on_submit(prompt, max_new_tokens, temperature, top_p, num_beams, stop_token):
-         return generate_answer(prompt, max_new_tokens, temperature, top_p, int(num_beams), stop_token)
-
-     submit.click(on_submit, inputs=[prompt_input, max_tokens, temperature, top_p, num_beams, stop_token], outputs=output)
-
-     gr.Markdown("**Disclaimer:** This demo is for evaluation and research. It is not a medical device.")
-
- if __name__ == "__main__":
-     demo.launch()
 
+ """
+ Gradio app that:
+ - Uses a local model if torch is installed,
+ - Otherwise tries Hugging Face InferenceClient,
+ - Otherwise falls back to legacy InferenceApi with task="text-generation".
+ Make sure HF_TOKEN is set in Space secrets if your model is private.
+ """
+
+ import os
+ from typing import Optional
+ import gradio as gr
+
+ MODEL_ID = "marvinisjarvis/radio_model"
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+ # Flags & clients
+ LOCAL_AVAILABLE = False
+ INFERENCE_CLIENT_AVAILABLE = False
+ INFERENCE_API_AVAILABLE = False
+
+ # Attempt local loading (torch + transformers)
+ try:
+     import torch
+     from transformers import AutoTokenizer, AutoModelForCausalLM
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     def try_load_local():
+         print("Attempting to load local model...")
+         tokenizer = AutoTokenizer.from_pretrained(
+             MODEL_ID, trust_remote_code=True, use_fast=True, token=HF_TOKEN
+         )
+         kwargs = {"trust_remote_code": True, "token": HF_TOKEN, "low_cpu_mem_usage": True}
+         if device == "cuda":
+             kwargs.update({"device_map": "auto", "torch_dtype": torch.float16})
+         model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
+         return tokenizer, model
+
+     try:
+         tokenizer, model = try_load_local()
+         LOCAL_AVAILABLE = True
+         print("Local model loaded.")
+     except Exception as e:
+         print("Local model load failed:", e)
+         LOCAL_AVAILABLE = False
+
+ except Exception as e:
+     print("Torch not available or failed to import:", e)
+     LOCAL_AVAILABLE = False
+
+
+ # Attempt to use InferenceClient (preferred)
+ if not LOCAL_AVAILABLE:
+     try:
+         from huggingface_hub import InferenceClient
+         client = InferenceClient(token=HF_TOKEN)
+         INFERENCE_CLIENT_AVAILABLE = True
+         print("InferenceClient available - will use remote text-generation via InferenceClient.")
+     except Exception as e:
+         print("InferenceClient not available:", e)
+         INFERENCE_CLIENT_AVAILABLE = False
+
+ # Fallback to legacy InferenceApi with explicit task
+ inference_api = None
+ if (not LOCAL_AVAILABLE) and (not INFERENCE_CLIENT_AVAILABLE):
+     try:
+         from huggingface_hub import InferenceApi
+         # Explicitly specify task to avoid "Task not specified" errors
+         inference_api = InferenceApi(repo_id=MODEL_ID, token=HF_TOKEN, task="text-generation")
+         INFERENCE_API_AVAILABLE = True
+         print("Using legacy InferenceApi with task='text-generation'.")
+     except Exception as e:
+         print("Hugging Face InferenceApi not available or failed:", e)
+         INFERENCE_API_AVAILABLE = False
+
+
+ # Generation wrapper that handles all three paths
+ def generate_answer(
+     prompt: str,
+     max_new_tokens: int = 256,
+     temperature: float = 0.7,
+     top_p: float = 0.9,
+     num_beams: int = 1,
+     stop_token: Optional[str] = None,
+ ) -> str:
+     if not prompt or prompt.strip() == "":
+         return "Please enter a prompt."
+
+     # Local path
+     if LOCAL_AVAILABLE:
+         try:
+             inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
+             device0 = next(model.parameters()).device
+             input_ids = inputs["input_ids"].to(device0)
+             attention_mask = inputs.get("attention_mask")
+             if attention_mask is not None:
+                 attention_mask = attention_mask.to(device0)
+
+             gen_kwargs = dict(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 max_new_tokens=int(max_new_tokens),
+                 temperature=float(temperature),
+                 top_p=float(top_p),
+                 num_beams=int(num_beams),
+                 eos_token_id=getattr(tokenizer, "eos_token_id", None),
+                 pad_token_id=getattr(tokenizer, "pad_token_id", None),
+                 do_sample=(float(temperature) > 0) and (int(num_beams) == 1),
+             )
+             outputs = model.generate(**gen_kwargs)
+             decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+             result = decoded[len(prompt) :].strip() if decoded.startswith(prompt) else decoded.strip()
+             if stop_token:
+                 idx = result.find(stop_token)
+                 if idx != -1:
+                     result = result[:idx].strip()
+             return result
+         except Exception as e:
+             print("Local generation failed, falling back to remote. Error:", e)
+
+     # InferenceClient path (preferred remote)
+     if INFERENCE_CLIENT_AVAILABLE:
+         try:
+             # The InferenceClient text_generation method takes kwargs for parameters
+             response = client.text_generation(
+                 model=MODEL_ID,
+                 prompt=prompt,
+                 max_new_tokens=int(max_new_tokens),
+                 temperature=float(temperature),
+                 top_p=float(top_p),
+                 num_beams=int(num_beams),
+             )
+             # The response is the generated text string
+             out = response
+             if out.startswith(prompt):
+                 out = out[len(prompt) :].strip()
+             if stop_token:
+                 idx = out.find(stop_token)
+                 if idx != -1:
+                     out = out[:idx].strip()
+             return out.strip()
+         except Exception as e:
+             print("InferenceClient call failed, will try legacy InferenceApi. Error:", e)
+
+     # Legacy InferenceApi fallback (explicit task)
+     if INFERENCE_API_AVAILABLE and inference_api is not None:
+         try:
+             params = {"max_new_tokens": int(max_new_tokens), "temperature": float(temperature), "top_p": float(top_p)}
+             res = inference_api(prompt, params=params)
+             # normalize response
+             if isinstance(res, str):
+                 out = res
+             elif isinstance(res, dict) and "generated_text" in res:
+                 out = res["generated_text"]
+             elif isinstance(res, list) and res and isinstance(res[0], dict) and "generated_text" in res[0]:
+                 out = res[0]["generated_text"]
+             else:
+                 out = str(res)
+             if out.startswith(prompt):
+                 out = out[len(prompt) :].strip()
+             if stop_token:
+                 idx = out.find(stop_token)
+                 if idx != -1:
+                     out = out[:idx].strip()
+             return out.strip()
+         except Exception as e:
+             print("Legacy InferenceApi failed:", e)
+             return f"Remote inference failed: {e}"
+
+     return ("No inference path available. Install torch for local inference or ensure HF_TOKEN is set "
+             "and huggingface_hub supports InferenceClient/InferenceApi.")
+
+
+ # --- Gradio UI ---
+ title = "RadioModel — Radiology Q&A (Mistral 7B fine-tuned)"
+ description = """
+ Demo for marvinisjarvis/radio_model.
+ Tries local inference first; otherwise uses Hugging Face remote inference.
+ If your model is private, add HF_TOKEN in Space secrets. Not for clinical use.
+ """
+
+ with gr.Blocks(title=title) as demo:
+     gr.Markdown(f"## {title}")
+     gr.Markdown(description)
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             prompt_input = gr.Textbox(label="Enter your radiology question", lines=6)
+             submit = gr.Button("Generate Answer")
+             examples = gr.Examples(
+                 examples=[
+                     "What does an X-ray of pneumonia typically show?",
+                     "How can you differentiate a benign lung nodule from a malignant one on CT?",
+                     "What are common signs of bone fracture on X-rays?",
+                     "Which imaging modality is best for detecting small brain tumors?"
+                 ],
+                 inputs=prompt_input
+             )
+         with gr.Column(scale=2):
+             max_tokens = gr.Slider(32, 1024, value=256, step=32, label="Max New Tokens")
+             temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
+             top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
+             num_beams = gr.Slider(1, 5, value=1, step=1, label="Num Beams")
+             stop_token = gr.Textbox(label="Optional stop token", placeholder="e.g., ### or <END>", lines=1)
+
+     output = gr.Textbox(label="Model output", lines=14)
+
+     def on_submit(prompt, max_new_tokens, temperature, top_p, num_beams, stop_token):
+         return generate_answer(prompt, max_new_tokens, temperature, top_p, int(num_beams), stop_token)
+
+     submit.click(on_submit, inputs=[prompt_input, max_tokens, temperature, top_p, num_beams, stop_token], outputs=output)
+
+     gr.Markdown("Disclaimer: This demo is for evaluation and research. It is not a medical device.")
+
+ if __name__ == "__main__":
+     demo.launch()
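
As a usage reference, the sketch below shows the InferenceClient.text_generation pattern the new app.py relies on, assuming a recent huggingface_hub release and a valid HF_TOKEN. The call takes the prompt plus keyword parameters and returns the generated continuation as a plain string; the sketch omits num_beams, which does not appear to be among text_generation's documented keyword parameters, so forwarding it may raise an error on some client versions.

```python
# Minimal sketch of the remote text-generation call used by the new app.py
# (assumes huggingface_hub is installed and HF_TOKEN grants access to the model).
import os
from huggingface_hub import InferenceClient

client = InferenceClient(token=os.environ.get("HF_TOKEN"))

# Returns the generated text as a plain string, so no dict unpacking is needed.
answer = client.text_generation(
    "What does an X-ray of pneumonia typically show?",
    model="marvinisjarvis/radio_model",
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.9,
)
print(answer)
```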