Spaces:
Runtime error
Runtime error
Update visualization.py
Browse files- visualization.py +10 -8
visualization.py
CHANGED
|
@@ -233,18 +233,20 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, outpu
|
|
| 233 |
combined_mse[0] = mse_embeddings_norm # Use normalized MSE values for facial
|
| 234 |
combined_mse[1] = mse_posture_norm # Use normalized MSE values for posture
|
| 235 |
|
|
|
|
| 236 |
cdict = {
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
}
|
| 244 |
|
| 245 |
custom_cmap = LinearSegmentedColormap('custom_cmap', segmentdata=cdict, N=256)
|
| 246 |
|
| 247 |
fig, ax = plt.subplots(figsize=(width/100, 2))
|
|
|
|
| 248 |
im = ax.imshow(combined_mse, aspect='auto', cmap=custom_cmap, extent=[0, total_frames, 0, 2])
|
| 249 |
ax.set_yticks([0.5, 1.5])
|
| 250 |
ax.set_yticklabels(['Face', 'Posture'])
|
|
@@ -280,4 +282,4 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, outpu
|
|
| 280 |
out.release()
|
| 281 |
plt.close(fig)
|
| 282 |
|
| 283 |
-
return output_path
|
|
|
|
| 233 |
combined_mse[0] = mse_embeddings_norm # Use normalized MSE values for facial
|
| 234 |
combined_mse[1] = mse_posture_norm # Use normalized MSE values for posture
|
| 235 |
|
| 236 |
+
# Custom colormap definition
|
| 237 |
cdict = {
|
| 238 |
+
'red': [(0.0, 0.0, 0.0), # Low MSE: No red
|
| 239 |
+
(1.0, 1.0, 1.0)], # High MSE: Full red
|
| 240 |
+
'green': [(0.0, 1.0, 1.0), # Low MSE: Full green
|
| 241 |
+
(1.0, 0.0, 0.0)], # High MSE: No green
|
| 242 |
+
'blue': [(0.0, 1.0, 1.0), # Low MSE: Full blue
|
| 243 |
+
(1.0, 0.0, 0.0)] # High MSE: No blue
|
| 244 |
+
}
|
| 245 |
|
| 246 |
custom_cmap = LinearSegmentedColormap('custom_cmap', segmentdata=cdict, N=256)
|
| 247 |
|
| 248 |
fig, ax = plt.subplots(figsize=(width/100, 2))
|
| 249 |
+
# Use the custom colormap in the heatmap generation
|
| 250 |
im = ax.imshow(combined_mse, aspect='auto', cmap=custom_cmap, extent=[0, total_frames, 0, 2])
|
| 251 |
ax.set_yticks([0.5, 1.5])
|
| 252 |
ax.set_yticklabels(['Face', 'Posture'])
|
|
|
|
| 282 |
out.release()
|
| 283 |
plt.close(fig)
|
| 284 |
|
| 285 |
+
return output_path
|