  """
  Gradio app that:
  - Uses a local model if torch is installed,
  - Otherwise tries Hugging Face InferenceClient,
  - Otherwise falls back to legacy InferenceApi with task="text-generation".
  Make sure HF_TOKEN is set in Space secrets if your model is private.
  """

  import os
  from typing import Optional
  import gradio as gr
import torch

  MODEL_ID = "marvinisjarvis/radio_model"
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
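# For local testing (a sketch, assuming this file is saved as app.py), the token can
# also be supplied via the environment instead of Space secrets, e.g.:
#   HF_TOKEN=hf_xxxxxxxx python app.py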

  # Flags & clients
  LOCAL_AVAILABLE = False
  INFERENCE_CLIENT_AVAILABLE = False
  INFERENCE_API_AVAILABLE = False

  # Attempt local loading (torch + transformers)
  try:
      import torch
      from transformers import AutoTokenizer, AutoModelForCausalLM

      device = "cuda" if torch.cuda.is_available() else "cpu"

      def try_load_local():
          print("Attempting to load local model...")
          tokenizer = AutoTokenizer.from_pretrained(
              MODEL_ID, trust_remote_code=True, use_fast=True, token=HF_TOKEN
          )
          kwargs = {"trust_remote_code": True, "token": HF_TOKEN, "low_cpu_mem_usage": True}
          if device == "cuda":
              kwargs.update({"device_map": "auto", "torch_dtype": torch.float16})
          model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
          return tokenizer, model

      try:
          tokenizer, model = try_load_local()
          LOCAL_AVAILABLE = True
          print("Local model loaded.")
      except Exception as e:
          print("Local model load failed:", e)
          LOCAL_AVAILABLE = False

  except Exception as e:
      print("Torch not available or failed to import:", e)
      LOCAL_AVAILABLE = False


# Attempt to use InferenceClient (preferred)
if not LOCAL_AVAILABLE:
    try:
        from huggingface_hub import InferenceClient
        client = InferenceClient(token=HF_TOKEN)
        INFERENCE_CLIENT_AVAILABLE = True
        print("InferenceClient available - will use remote text-generation via InferenceClient.")
    except Exception as e:
        print("InferenceClient not available:", e)
        INFERENCE_CLIENT_AVAILABLE = False

# Fallback to legacy InferenceApi with explicit task
inference_api = None
if (not LOCAL_AVAILABLE) and (not INFERENCE_CLIENT_AVAILABLE):
    try:
        from huggingface_hub import InferenceApi
        # Explicitly specify task to avoid "Task not specified" errors
        inference_api = InferenceApi(repo_id=MODEL_ID, token=HF_TOKEN, task="text-generation")
        INFERENCE_API_AVAILABLE = True
        print("Using legacy InferenceApi with task='text-generation'.")
    except Exception as e:
        print("Hugging Face InferenceApi not available or failed:", e)
        INFERENCE_API_AVAILABLE = False


# Generation wrapper that handles all three paths
def generate_answer(
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    num_beams: int = 1,
    stop_token: Optional[str] = None,
) -> str:
    if not prompt or prompt.strip() == "":
        return "Please enter a prompt."

    # Local path
    if LOCAL_AVAILABLE:
        try:
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
            device0 = next(model.parameters()).device
            input_ids = inputs["input_ids"].to(device0)
            attention_mask = inputs.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(device0)

            gen_kwargs = dict(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=int(max_new_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
                num_beams=int(num_beams),
                eos_token_id=getattr(tokenizer, "eos_token_id", None),
                pad_token_id=getattr(tokenizer, "pad_token_id", None),
                do_sample=(float(temperature) > 0) and (int(num_beams) == 1),
            )
            outputs = model.generate(**gen_kwargs)
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
            result = decoded[len(prompt):].strip() if decoded.startswith(prompt) else decoded.strip()
            if stop_token:
                idx = result.find(stop_token)
                if idx != -1:
                    result = result[:idx].strip()
            return result
        except Exception as e:
            print("Local generation failed, falling back to remote. Error:", e)

    # InferenceClient path (preferred remote)
    if INFERENCE_CLIENT_AVAILABLE:
        try:
            # InferenceClient.text_generation takes sampling parameters as kwargs.
            # num_beams is only used for local generation; the hosted text-generation
            # API does not accept a beam-search parameter, so it is not forwarded here.
            response = client.text_generation(
                prompt,
                model=MODEL_ID,
                max_new_tokens=int(max_new_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
            )
            # The response is the generated text string
            out = response
            if out.startswith(prompt):
                out = out[len(prompt):].strip()
            if stop_token:
                idx = out.find(stop_token)
                if idx != -1:
                    out = out[:idx].strip()
            return out.strip()
        except Exception as e:
            print("InferenceClient call failed, will try legacy InferenceApi. Error:", e)

    # Legacy InferenceApi fallback (explicit task)
    if INFERENCE_API_AVAILABLE and inference_api is not None:
        try:
            params = {"max_new_tokens": int(max_new_tokens), "temperature": float(temperature), "top_p": float(top_p)}
            res = inference_api(prompt, params=params)
            # normalize response
            if isinstance(res, str):
                out = res
            elif isinstance(res, dict) and "generated_text" in res:
                out = res["generated_text"]
            elif isinstance(res, list) and res and isinstance(res[0], dict) and "generated_text" in res[0]:
                out = res[0]["generated_text"]
            else:
                out = str(res)
            if out.startswith(prompt):
                out = out[len(prompt):].strip()
            if stop_token:
                idx = out.find(stop_token)
                if idx != -1:
                    out = out[:idx].strip()
            return out.strip()
        except Exception as e:
            print("Legacy InferenceApi failed:", e)
            return f"Remote inference failed: {e}"

    return (
        "No inference path available. Install torch for local inference or ensure "
        "HF_TOKEN is set and huggingface_hub supports InferenceClient/InferenceApi."
    )
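
# Quick smoke test (a sketch, not wired into the UI): calling the wrapper directly
# exercises whichever backend was selected above, e.g.
#   print(generate_answer("What does lobar pneumonia look like on a chest X-ray?",
#                         max_new_tokens=64))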


# --- Gradio UI ---
title = "RadioModel — Radiology Q&A (Mistral 7B fine-tuned)"
description = """
Demo for marvinisjarvis/radio_model.
Tries local inference first; otherwise uses Hugging Face remote inference.
If your model is private, add HF_TOKEN in Space secrets. Not for clinical use.
"""

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(label="Enter your radiology question", lines=6)
            submit = gr.Button("Generate Answer")
            examples = gr.Examples(
                examples=[
                    "What does an X-ray of pneumonia typically show?",
                    "How can you differentiate a benign lung nodule from a malignant one on CT?",
                    "What are common signs of bone fracture on X-rays?",
                    "Which imaging modality is best for detecting small brain tumors?",
                ],
                inputs=prompt_input,
            )
        with gr.Column(scale=2):
            max_tokens = gr.Slider(32, 1024, value=256, step=32, label="Max New Tokens")
            temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
            num_beams = gr.Slider(1, 5, value=1, step=1, label="Num Beams")
            stop_token = gr.Textbox(label="Optional stop token", placeholder="e.g., ### or <END>", lines=1)

    output = gr.Textbox(label="Model output", lines=14)

    def on_submit(prompt, max_new_tokens, temperature, top_p, num_beams, stop_token):
        return generate_answer(prompt, max_new_tokens, temperature, top_p, int(num_beams), stop_token)

    submit.click(on_submit, inputs=[prompt_input, max_tokens, temperature, top_p, num_beams, stop_token], outputs=output)

    gr.Markdown("Disclaimer: This demo is for evaluation and research. It is not a medical device.")

if __name__ == "__main__":
    demo.launch()
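
# Suggested Space dependencies (an assumption, not taken from the original repo's
# requirements.txt): gradio and huggingface_hub for the remote paths, plus torch,
# transformers, and accelerate if the local-inference path should be available.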