Mohaddz commited on
Commit
e555569
Β·
verified Β·
1 Parent(s): 1250ed8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -3
app.py CHANGED
@@ -118,15 +118,35 @@ else:
118
  print("Skipping TensorFlow model loading because prerequisite Hugging Face models failed to load.")
119
 
120
  # Load parts list from JSON
 
121
  PARTS_LIST_FILE = 'cars117.json'
122
  all_parts = []
123
  if os.path.exists(PARTS_LIST_FILE):
124
  with open(PARTS_LIST_FILE, 'r', encoding='utf-8') as f:
125
  data = json.load(f)
 
126
  all_parts = sorted(list(set(part for entry in data.values() for part in entry.get('replaced_parts', []))))
127
- if dl_model and dl_model.output_shape[-1] != len(all_parts):
128
- print(f"Warning: TensorFlow model output classes ({dl_model.output_shape[-1]}) "
129
- f"does not match number of parts in JSON ({len(all_parts)}). Predictions may be misaligned.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  else:
131
  print(f"Error: Parts list file '{PARTS_LIST_FILE}' not found. Predicted part names will be unavailable.")
132
 
 
118
  print("Skipping TensorFlow model loading because prerequisite Hugging Face models failed to load.")
119
 
120
  # Load parts list from JSON
121
+ # --- (Corrected Code) ---
122
  PARTS_LIST_FILE = 'cars117.json'
123
  all_parts = []
124
  if os.path.exists(PARTS_LIST_FILE):
125
  with open(PARTS_LIST_FILE, 'r', encoding='utf-8') as f:
126
  data = json.load(f)
127
+ # Get the unique, sorted list of parts from the JSON file
128
  all_parts = sorted(list(set(part for entry in data.values() for part in entry.get('replaced_parts', []))))
129
+
130
+ # FIX: Dynamically handle mismatch between model output and parts list
131
+ if dl_model is not None:
132
+ model_output_size = dl_model.output_shape[-1]
133
+ parts_list_size = len(all_parts)
134
+
135
+ if model_output_size != parts_list_size:
136
+ print(f"WARNING: Model output size ({model_output_size}) and parts list size ({parts_list_size}) do not match.")
137
+
138
+ # If the model expects more outputs, pad the parts list with dummy values
139
+ if model_output_size > parts_list_size:
140
+ diff = model_output_size - parts_list_size
141
+ print(f"Padding the parts list with {diff} dummy entries to prevent a crash.")
142
+ for i in range(diff):
143
+ all_parts.append(f"_dummy_part_{i+1}_")
144
+
145
+ # If the model expects fewer outputs, truncate the parts list
146
+ else:
147
+ diff = parts_list_size - model_output_size
148
+ print(f"Truncating the parts list by {diff} entries to match the model's output.")
149
+ all_parts = all_parts[:model_output_size]
150
  else:
151
  print(f"Error: Parts list file '{PARTS_LIST_FILE}' not found. Predicted part names will be unavailable.")
152