Spaces:
Runtime error
Runtime error
Update visualization.py
Browse files- visualization.py +35 -25
visualization.py
CHANGED
|
@@ -296,35 +296,45 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
|
|
| 296 |
print(f"Failed to create heatmap video at: {heatmap_video_path}")
|
| 297 |
return None
|
| 298 |
|
| 299 |
-
def create_heatmap(t,
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
combined_mse[0] = mse_embeddings_norm
|
|
|
|
| 309 |
combined_mse[1] = mse_posture_norm
|
| 310 |
combined_mse[2] = mse_voice_norm
|
| 311 |
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
return heatmap_img
|
| 328 |
|
| 329 |
def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
| 330 |
data = np.vstack((mse_embeddings, mse_posture, mse_voice)).T
|
|
|
|
| 296 |
print(f"Failed to create heatmap video at: {heatmap_video_path}")
|
| 297 |
return None
|
| 298 |
|
| 299 |
+
def create_heatmap(t, mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered, fps, total_frames, width):
|
| 300 |
+
# Normalize the MSE values
|
| 301 |
+
mse_embeddings_norm = normalize_mse(mse_embeddings_filtered)
|
| 302 |
+
mse_posture_norm = normalize_mse(mse_posture_filtered)
|
| 303 |
+
mse_voice_norm = normalize_mse(mse_voice_filtered)
|
| 304 |
+
|
| 305 |
+
# Debug prints
|
| 306 |
+
print(f"mse_embeddings_norm shape: {mse_embeddings_norm.shape}")
|
| 307 |
+
print(f"mse_posture_norm shape: {mse_posture_norm.shape}")
|
| 308 |
+
print(f"mse_voice_norm shape: {mse_voice_norm.shape}")
|
| 309 |
+
|
| 310 |
+
# Ensure combined_mse has the correct shape
|
| 311 |
+
combined_mse = np.zeros((total_frames, width))
|
| 312 |
+
|
| 313 |
+
# Adjust shapes and pad with zeros if necessary
|
| 314 |
+
mse_embeddings_norm = pad_or_trim_array(mse_embeddings_norm, width)
|
| 315 |
+
mse_posture_norm = pad_or_trim_array(mse_posture_norm, width)
|
| 316 |
+
mse_voice_norm = pad_or_trim_array(mse_voice_norm, width)
|
| 317 |
+
|
| 318 |
combined_mse[0] = mse_embeddings_norm
|
| 319 |
+
# Assuming you combine posture and voice MSEs similarly
|
| 320 |
combined_mse[1] = mse_posture_norm
|
| 321 |
combined_mse[2] = mse_voice_norm
|
| 322 |
|
| 323 |
+
# Return or use combined_mse as needed
|
| 324 |
+
return combined_mse
|
| 325 |
+
|
| 326 |
+
def normalize_mse(mse):
|
| 327 |
+
# Your normalization logic here
|
| 328 |
+
return mse / np.max(mse)
|
| 329 |
+
|
| 330 |
+
def pad_or_trim_array(arr, target_length):
|
| 331 |
+
if len(arr) > target_length:
|
| 332 |
+
# Trim the array
|
| 333 |
+
return arr[:target_length]
|
| 334 |
+
elif len(arr) < target_length:
|
| 335 |
+
# Pad the array with zeros
|
| 336 |
+
return np.pad(arr, (0, target_length - len(arr)), 'constant')
|
| 337 |
+
return arr
|
|
|
|
| 338 |
|
| 339 |
def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
| 340 |
data = np.vstack((mse_embeddings, mse_posture, mse_voice)).T
|