fffiloni commited on
Commit
1de55b1
·
verified ·
1 Parent(s): 4cab197

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -14
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
3
  import re
4
  import os
5
  import json
 
6
  hf_token = os.environ.get('HF_TOKEN')
7
 
8
  from gradio_client import Client, handle_file
@@ -181,29 +182,32 @@ def parse_perfume_description(text: str) -> dict:
181
 
182
  return result
183
 
184
- def extract_image_description(data: dict) -> str:
185
  """
186
- Extracts the Image Description value from a Python dict response.
187
 
188
  Args:
189
- data (dict): The dict returned by your LLM.
 
190
 
191
  Returns:
192
- str: The Image Description text.
193
  """
 
 
 
 
 
 
194
  if not isinstance(data, dict):
195
- raise TypeError("Input must be a dict")
196
 
197
- image_description = (
198
- data.get("Image Description")
199
- or data.get("image_description")
200
- or None
201
- )
202
 
203
- if not image_description:
204
- raise KeyError("No 'Image Description' field found in the response")
205
 
206
- return image_description.strip()
207
 
208
 
209
  def get_text_after_colon(input_text):
@@ -240,7 +244,8 @@ def infer(image_input):
240
 
241
  parsed = parse_perfume_description(result)
242
 
243
- image_desc = extract_image_description(parsed)
 
244
  print(image_desc)
245
 
246
  return result, parsed
 
3
  import re
4
  import os
5
  import json
6
+ from typing import Union
7
  hf_token = os.environ.get('HF_TOKEN')
8
 
9
  from gradio_client import Client, handle_file
 
182
 
183
  return result
184
 
185
+ def extract_field(data: Union[str, dict], field_name: str) -> str:
186
  """
187
+ Extracts a specific field value from a JSON string or Python dict.
188
 
189
  Args:
190
+ data (Union[str, dict]): The JSON string or dict to extract from.
191
+ field_name (str): The exact field name to extract.
192
 
193
  Returns:
194
+ str: The extracted field value as a string.
195
  """
196
+ if isinstance(data, str):
197
+ try:
198
+ data = json.loads(data)
199
+ except json.JSONDecodeError:
200
+ raise ValueError("Invalid JSON string provided")
201
+
202
  if not isinstance(data, dict):
203
+ raise TypeError("Input must be a dict or a valid JSON string")
204
 
205
+ value = data.get(field_name) or data.get(field_name.lower()) or None
 
 
 
 
206
 
207
+ if value is None:
208
+ raise KeyError(f"No field named '{field_name}' found in the data")
209
 
210
+ return str(value).strip()
211
 
212
 
213
  def get_text_after_colon(input_text):
 
244
 
245
  parsed = parse_perfume_description(result)
246
 
247
+ image_desc = extract_field(parsed, "Image Description")
248
+
249
  print(image_desc)
250
 
251
  return result, parsed