""" | |
Gradio app that: | |
- Uses a local model if torch is installed, | |
- Otherwise tries Hugging Face InferenceClient, | |
- Otherwise falls back to legacy InferenceApi with task="text-generation". | |
Make sure HF_TOKEN is set in Space secrets if your model is private. | |
""" | |
import os
from typing import Optional

import gradio as gr

# Note: torch is imported lazily inside the try block below so the app can
# still start (and fall back to remote inference) when torch is not installed.

MODEL_ID = "marvinisjarvis/radio_model"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Flags & clients
LOCAL_AVAILABLE = False
INFERENCE_CLIENT_AVAILABLE = False
INFERENCE_API_AVAILABLE = False

# Attempt local loading (torch + transformers)
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def try_load_local():
        print("Attempting to load local model...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID, trust_remote_code=True, use_fast=True, token=HF_TOKEN
        )
        kwargs = {"trust_remote_code": True, "token": HF_TOKEN, "low_cpu_mem_usage": True}
        if device == "cuda":
            kwargs.update({"device_map": "auto", "torch_dtype": torch.float16})
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
        return tokenizer, model

    try:
        tokenizer, model = try_load_local()
        LOCAL_AVAILABLE = True
        print("Local model loaded.")
    except Exception as e:
        print("Local model load failed:", e)
        LOCAL_AVAILABLE = False
except Exception as e:
    print("Torch not available or failed to import:", e)
    LOCAL_AVAILABLE = False
# Attempt to use InferenceClient (preferred remote path)
if not LOCAL_AVAILABLE:
    try:
        from huggingface_hub import InferenceClient

        client = InferenceClient(token=HF_TOKEN)
        INFERENCE_CLIENT_AVAILABLE = True
        print("InferenceClient available - will use remote text-generation via InferenceClient.")
    except Exception as e:
        print("InferenceClient not available:", e)
        INFERENCE_CLIENT_AVAILABLE = False

# Fallback to legacy InferenceApi with an explicit task
inference_api = None
if (not LOCAL_AVAILABLE) and (not INFERENCE_CLIENT_AVAILABLE):
    try:
        from huggingface_hub import InferenceApi

        # Explicitly specify the task to avoid "Task not specified" errors
        inference_api = InferenceApi(repo_id=MODEL_ID, token=HF_TOKEN, task="text-generation")
        INFERENCE_API_AVAILABLE = True
        print("Using legacy InferenceApi with task='text-generation'.")
    except Exception as e:
        print("Hugging Face InferenceApi not available or failed:", e)
        INFERENCE_API_AVAILABLE = False
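# Optional: log which backend ended up active so the chosen path is obvious in
# the Space logs (purely informational; remove if the extra output is unwanted).
print(
    "Selected inference path:",
    "local" if LOCAL_AVAILABLE
    else "InferenceClient" if INFERENCE_CLIENT_AVAILABLE
    else "InferenceApi" if INFERENCE_API_AVAILABLE
    else "none",
)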
# Generation wrapper that handles all three paths
def generate_answer(
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    num_beams: int = 1,
    stop_token: Optional[str] = None,
) -> str:
    if not prompt or prompt.strip() == "":
        return "Please enter a prompt."

    # Local path
    if LOCAL_AVAILABLE:
        try:
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
            device0 = next(model.parameters()).device
            input_ids = inputs["input_ids"].to(device0)
            attention_mask = inputs.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(device0)
            gen_kwargs = dict(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=int(max_new_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
                num_beams=int(num_beams),
                eos_token_id=getattr(tokenizer, "eos_token_id", None),
                pad_token_id=getattr(tokenizer, "pad_token_id", None),
                do_sample=(float(temperature) > 0) and (int(num_beams) == 1),
            )
            outputs = model.generate(**gen_kwargs)
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
            result = decoded[len(prompt):].strip() if decoded.startswith(prompt) else decoded.strip()
            if stop_token:
                idx = result.find(stop_token)
                if idx != -1:
                    result = result[:idx].strip()
            return result
        except Exception as e:
            print("Local generation failed, falling back to remote. Error:", e)
    # InferenceClient path (preferred remote)
    if INFERENCE_CLIENT_AVAILABLE:
        try:
            # InferenceClient.text_generation takes sampling parameters as keyword
            # arguments; it does not expose a beam-search parameter, so num_beams
            # from the UI is not forwarded on this path.
            response = client.text_generation(
                model=MODEL_ID,
                prompt=prompt,
                max_new_tokens=int(max_new_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
            )
            # The response is the generated text string
            out = response
            if out.startswith(prompt):
                out = out[len(prompt):].strip()
            if stop_token:
                idx = out.find(stop_token)
                if idx != -1:
                    out = out[:idx].strip()
            return out.strip()
        except Exception as e:
            print("InferenceClient call failed, will try legacy InferenceApi. Error:", e)
    # Legacy InferenceApi fallback (explicit task)
    if INFERENCE_API_AVAILABLE and inference_api is not None:
        try:
            params = {"max_new_tokens": int(max_new_tokens), "temperature": float(temperature), "top_p": float(top_p)}
            res = inference_api(prompt, params=params)
            # Normalize the response (string, dict, or list of dicts)
            if isinstance(res, str):
                out = res
            elif isinstance(res, dict) and "generated_text" in res:
                out = res["generated_text"]
            elif isinstance(res, list) and res and isinstance(res[0], dict) and "generated_text" in res[0]:
                out = res[0]["generated_text"]
            else:
                out = str(res)
            if out.startswith(prompt):
                out = out[len(prompt):].strip()
            if stop_token:
                idx = out.find(stop_token)
                if idx != -1:
                    out = out[:idx].strip()
            return out.strip()
        except Exception as e:
            print("Legacy InferenceApi failed:", e)
            return f"Remote inference failed: {e}"

    return (
        "No inference path available. Install torch for local inference or ensure HF_TOKEN is set "
        "and huggingface_hub supports InferenceClient/InferenceApi."
    )
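# Illustrative usage (not executed by the app): the wrapper can be called
# directly, e.g. from a notebook or a quick smoke test; the prompt below is
# only a placeholder.
#   generate_answer("What does an X-ray of pneumonia typically show?", max_new_tokens=128)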
# --- Gradio UI ---
title = "RadioModel — Radiology Q&A (Mistral 7B fine-tuned)"
description = """
Demo for marvinisjarvis/radio_model.
Tries local inference first; otherwise uses Hugging Face remote inference.
If your model is private, add HF_TOKEN in Space secrets. Not for clinical use.
"""

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(label="Enter your radiology question", lines=6)
            submit = gr.Button("Generate Answer")
            examples = gr.Examples(
                examples=[
                    "What does an X-ray of pneumonia typically show?",
                    "How can you differentiate a benign lung nodule from a malignant one on CT?",
                    "What are common signs of bone fracture on X-rays?",
                    "Which imaging modality is best for detecting small brain tumors?",
                ],
                inputs=prompt_input,
            )
        with gr.Column(scale=2):
            max_tokens = gr.Slider(32, 1024, value=256, step=32, label="Max New Tokens")
            temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
            num_beams = gr.Slider(1, 5, value=1, step=1, label="Num Beams")
            stop_token = gr.Textbox(label="Optional stop token", placeholder="e.g., ### or <END>", lines=1)

    output = gr.Textbox(label="Model output", lines=14)

    def on_submit(prompt, max_new_tokens, temperature, top_p, num_beams, stop_token):
        return generate_answer(prompt, max_new_tokens, temperature, top_p, int(num_beams), stop_token)

    submit.click(
        on_submit,
        inputs=[prompt_input, max_tokens, temperature, top_p, num_beams, stop_token],
        outputs=output,
    )

    gr.Markdown("Disclaimer: This demo is for evaluation and research. It is not a medical device.")

if __name__ == "__main__":
    demo.launch()