multimodalart (HF Staff) committed
Commit c301058 · verified · 1 Parent(s): a8e6bb8

Update app.py

Files changed (1):
  1. app.py +5 -7
app.py CHANGED

@@ -65,8 +65,6 @@ def get_localization_prompt(pil_image: Image.Image, instruction: str) -> List[dict[str, Any]]:
 
 @spaces.GPU(duration=120)
 def run_inference_localization(
-    current_model: AutoModelForImageTextToText,
-    current_processor: AutoProcessor,
     messages_for_template: List[dict[str, Any]],
     pil_image_for_processing: Image.Image
 ) -> str:
@@ -80,7 +78,7 @@ def run_inference_localization(
     """
     # 1. Apply chat template to messages. This will create the text part of the prompt,
     # including image tags if the image was part of `messages_for_template`.
-    text_prompt = current_processor.apply_chat_template(
+    text_prompt = processor.apply_chat_template(
         messages_for_template,
         tokenize=False,
         add_generation_prompt=True
@@ -93,11 +91,11 @@ def run_inference_localization(
         padding=True,
         return_tensors="pt",
     )
-    inputs = inputs.to(current_model.device)
+    inputs = inputs.to(model.device)
 
     # 3. Generate response
     # Using do_sample=False for more deterministic output, as in the model card's structured output example
-    generated_ids = current_model.generate(**inputs, max_new_tokens=128, do_sample=False)
+    generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
 
     # 4. Trim input_ids from generated_ids to get only the generated part
     generated_ids_trimmed = [
@@ -105,7 +103,7 @@ def run_inference_localization(
     ]
 
     # 5. Decode the generated tokens
-    decoded_output = current_processor.batch_decode(
+    decoded_output = processor.batch_decode(
         generated_ids_trimmed,
         skip_special_tokens=True,
         clean_up_tokenization_spaces=False
@@ -152,7 +150,7 @@ def predict_click_location(input_pil_image: Image.Image, instruction: str):
     # Pass `messages` (which includes the image object for template processing)
     # and `resized_image` (for actual tensor conversion).
     try:
-        coordinates_str = run_inference_localization(model, processor, messages, resized_image)
+        coordinates_str = run_inference_localization(messages, resized_image)
     except Exception as e:
         print(f"Error during model inference: {e}")
         return f"Error during model inference: {e}", resized_image.copy().convert("RGB")
 