"""
Gradio app that:
- Uses a local model if torch is installed,
- Otherwise tries Hugging Face InferenceClient,
- Otherwise falls back to legacy InferenceApi with task="text-generation".
Make sure HF_TOKEN is set in Space secrets if your model is private.
"""
import os
from typing import Optional
import gradio as gr
MODEL_ID = "marvinisjarvis/radio_model"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Flags & clients
LOCAL_AVAILABLE = False
INFERENCE_CLIENT_AVAILABLE = False
INFERENCE_API_AVAILABLE = False
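# At most one of these flags ends up True; all three stay False if every inference path fails.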
# Attempt local loading (torch + transformers)
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def try_load_local():
        print("Attempting to load local model...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID, trust_remote_code=True, use_fast=True, token=HF_TOKEN
        )
        kwargs = {"trust_remote_code": True, "token": HF_TOKEN, "low_cpu_mem_usage": True}
        if device == "cuda":
            kwargs.update({"device_map": "auto", "torch_dtype": torch.float16})
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
        return tokenizer, model

    try:
        tokenizer, model = try_load_local()
        LOCAL_AVAILABLE = True
        print("Local model loaded.")
    except Exception as e:
        print("Local model load failed:", e)
        LOCAL_AVAILABLE = False
except Exception as e:
    print("Torch not available or failed to import:", e)
    LOCAL_AVAILABLE = False
# Attempt to use InferenceClient (preferred)
if not LOCAL_AVAILABLE:
    try:
        from huggingface_hub import InferenceClient

        client = InferenceClient(token=HF_TOKEN)
        INFERENCE_CLIENT_AVAILABLE = True
        print("InferenceClient available - will use remote text-generation via InferenceClient.")
    except Exception as e:
        print("InferenceClient not available:", e)
        INFERENCE_CLIENT_AVAILABLE = False
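# Note: InferenceClient calls go to Hugging Face's hosted inference service; whether this
# particular model is actually served there depends on its deployment status (not guaranteed).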
# Fallback to legacy InferenceApi with explicit task
inference_api = None
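# InferenceApi is deprecated in recent huggingface_hub releases; it is kept here only as a
# last-resort fallback when InferenceClient cannot be imported.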
if (not LOCAL_AVAILABLE) and (not INFERENCE_CLIENT_AVAILABLE):
    try:
        from huggingface_hub import InferenceApi

        # Explicitly specify task to avoid "Task not specified" errors
        inference_api = InferenceApi(repo_id=MODEL_ID, token=HF_TOKEN, task="text-generation")
        INFERENCE_API_AVAILABLE = True
        print("Using legacy InferenceApi with task='text-generation'.")
    except Exception as e:
        print("Hugging Face InferenceApi not available or failed:", e)
        INFERENCE_API_AVAILABLE = False
# Generation wrapper that handles all three paths
def generate_answer(
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    num_beams: int = 1,
    stop_token: Optional[str] = None,
) -> str:
    if not prompt or prompt.strip() == "":
        return "Please enter a prompt."
    # Local path
    if LOCAL_AVAILABLE:
        try:
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
            device0 = next(model.parameters()).device
            input_ids = inputs["input_ids"].to(device0)
            attention_mask = inputs.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(device0)
            gen_kwargs = dict(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=int(max_new_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
                num_beams=int(num_beams),
                eos_token_id=getattr(tokenizer, "eos_token_id", None),
                pad_token_id=getattr(tokenizer, "pad_token_id", None),
                # Sample only when temperature > 0 and beam search is not requested.
                do_sample=(float(temperature) > 0) and (int(num_beams) == 1),
            )
            outputs = model.generate(**gen_kwargs)
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Strip the echoed prompt so only the completion is returned.
            result = decoded[len(prompt) :].strip() if decoded.startswith(prompt) else decoded.strip()
            if stop_token:
                idx = result.find(stop_token)
                if idx != -1:
                    result = result[:idx].strip()
            return result
        except Exception as e:
            print("Local generation failed, falling back to remote. Error:", e)
    # InferenceClient path (preferred remote)
    if INFERENCE_CLIENT_AVAILABLE:
        try:
            # InferenceClient.text_generation takes the sampling parameters as kwargs.
            # Note: the hosted text-generation endpoint does not expose beam search,
            # so num_beams only affects the local path and is not forwarded here.
            response = client.text_generation(
                prompt=prompt,
                model=MODEL_ID,
                max_new_tokens=int(max_new_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
            )
            # The response is the generated text string
            out = response
            if out.startswith(prompt):
                out = out[len(prompt) :].strip()
            if stop_token:
                idx = out.find(stop_token)
                if idx != -1:
                    out = out[:idx].strip()
            return out.strip()
        except Exception as e:
            print("InferenceClient call failed, will try legacy InferenceApi. Error:", e)
    # Legacy InferenceApi fallback (explicit task)
    if INFERENCE_API_AVAILABLE and inference_api is not None:
        try:
            params = {"max_new_tokens": int(max_new_tokens), "temperature": float(temperature), "top_p": float(top_p)}
            res = inference_api(prompt, params=params)
            # normalize response (str, dict, or list of dicts)
            if isinstance(res, str):
                out = res
            elif isinstance(res, dict) and "generated_text" in res:
                out = res["generated_text"]
            elif isinstance(res, list) and res and isinstance(res[0], dict) and "generated_text" in res[0]:
                out = res[0]["generated_text"]
            else:
                out = str(res)
            if out.startswith(prompt):
                out = out[len(prompt) :].strip()
            if stop_token:
                idx = out.find(stop_token)
                if idx != -1:
                    out = out[:idx].strip()
            return out.strip()
        except Exception as e:
            print("Legacy InferenceApi failed:", e)
            return f"Remote inference failed: {e}"
return ("No inference path available. Install torch for local inference or ensure HF_TOKEN is set and huggingface_hub
supports InferenceClient/InferenceApi.")
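# Example manual smoke test (assumes this file is saved as app.py):
#   python -c "import app; print(app.generate_answer('What does pneumonia look like on X-ray?', max_new_tokens=64))"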
# --- Gradio UI ---
title = "RadioModel — Radiology Q&A (Mistral 7B fine-tuned)"
description = """
Demo for marvinisjarvis/radio_model.
Tries local inference first; otherwise uses Hugging Face remote inference.
If your model is private, add HF_TOKEN in Space secrets. Not for clinical use.
"""
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(label="Enter your radiology question", lines=6)
            submit = gr.Button("Generate Answer")
            examples = gr.Examples(
                examples=[
                    "What does an X-ray of pneumonia typically show?",
                    "How can you differentiate a benign lung nodule from a malignant one on CT?",
                    "What are common signs of bone fracture on X-rays?",
                    "Which imaging modality is best for detecting small brain tumors?",
                ],
                inputs=prompt_input,
            )
        with gr.Column(scale=2):
            max_tokens = gr.Slider(32, 1024, value=256, step=32, label="Max New Tokens")
            temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
            num_beams = gr.Slider(1, 5, value=1, step=1, label="Num Beams")
            stop_token = gr.Textbox(label="Optional stop token", placeholder="e.g., ### or <END>", lines=1)
    output = gr.Textbox(label="Model output", lines=14)
    def on_submit(prompt, max_new_tokens, temperature, top_p, num_beams, stop_token):
        return generate_answer(prompt, max_new_tokens, temperature, top_p, int(num_beams), stop_token)

    submit.click(
        on_submit,
        inputs=[prompt_input, max_tokens, temperature, top_p, num_beams, stop_token],
        outputs=output,
    )
    gr.Markdown("Disclaimer: This demo is for evaluation and research. It is not a medical device.")
if __name__ == "__main__":
    demo.launch()