Spaces:
Runtime error
Runtime error
Update visualization.py
Browse files- visualization.py +6 -2
visualization.py
CHANGED
|
@@ -153,9 +153,13 @@ def plot_mse_heatmap(mse_values, title, df):
|
|
| 153 |
sns.heatmap(mse_2d, cmap='YlOrRd', cbar=False, ax=ax)
|
| 154 |
|
| 155 |
# Set x-axis ticks to timecodes
|
| 156 |
-
num_ticks = 60
|
| 157 |
tick_locations = np.linspace(0, len(mse_values) - 1, num_ticks).astype(int)
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
ax.set_xticks(tick_locations)
|
| 161 |
ax.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')
|
|
|
|
| 153 |
sns.heatmap(mse_2d, cmap='YlOrRd', cbar=False, ax=ax)
|
| 154 |
|
| 155 |
# Set x-axis ticks to timecodes
|
| 156 |
+
num_ticks = min(60, len(mse_values))
|
| 157 |
tick_locations = np.linspace(0, len(mse_values) - 1, num_ticks).astype(int)
|
| 158 |
+
|
| 159 |
+
# Ensure tick_locations are within bounds
|
| 160 |
+
tick_locations = tick_locations[tick_locations < len(df)]
|
| 161 |
+
|
| 162 |
+
tick_labels = [df['Timecode'].iloc[i] if i < len(df) else '' for i in tick_locations]
|
| 163 |
|
| 164 |
ax.set_xticks(tick_locations)
|
| 165 |
ax.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')
|