Spaces:
Runtime error
Runtime error
Commit
·
81caeb5
1
Parent(s):
ba35f85
added weighted datasets
Browse files
- preprocessing/dataset.py +9 -5
preprocessing/dataset.py
CHANGED
@@ -78,7 +78,6 @@ class SongDataset(Dataset):
|
|
78 |
return waveform, dance_labels
|
79 |
else:
|
80 |
# WARNING: Could cause train/test split leak
|
81 |
-
print("Invalid output, trying next index...")
|
82 |
return self[idx - 1]
|
83 |
|
84 |
def _idx2audio_idx(self, idx: int) -> int:
|
@@ -112,8 +111,8 @@ class SongDataset(Dataset):
|
|
112 |
is_finite = not torch.any(torch.isinf(x))
|
113 |
is_numerical = not torch.any(torch.isnan(x))
|
114 |
has_data = torch.any(x != 0.0)
|
115 |
-
|
116 |
-
return all((is_finite, is_numerical, has_data,
|
117 |
|
118 |
def _waveform_from_index(self, idx: int) -> torch.Tensor:
|
119 |
audio_index, frame_index = self._get_audio_loc_from_idx(idx)
|
@@ -365,8 +364,13 @@ class DanceDataModule(pl.LightningDataModule):
|
|
365 |
dataset = (
|
366 |
self.dataset.dataset if isinstance(self.dataset, Subset) else self.dataset
|
367 |
)
|
368 |
-
|
369 |
-
|
|
|
|
|
|
|
|
|
|
|
370 |
|
371 |
|
372 |
def find_mean_std(dataset: Dataset, zscore=1.96, moe=0.02, p=0.5):
|
|
|
78 |
return waveform, dance_labels
|
79 |
else:
|
80 |
# WARNING: Could cause train/test split leak
|
|
|
81 |
return self[idx - 1]
|
82 |
|
83 |
def _idx2audio_idx(self, idx: int) -> int:
|
|
|
111 |
is_finite = not torch.any(torch.isinf(x))
|
112 |
is_numerical = not torch.any(torch.isnan(x))
|
113 |
has_data = torch.any(x != 0.0)
|
114 |
+
is_probability = torch.all(y >= -0.0001) and torch.all(y <= 1.0001)
|
115 |
+
return all((is_finite, is_numerical, has_data, is_probability))
|
116 |
|
117 |
def _waveform_from_index(self, idx: int) -> torch.Tensor:
|
118 |
audio_index, frame_index = self._get_audio_loc_from_idx(idx)
|
|
|
364 |
dataset = (
|
365 |
self.dataset.dataset if isinstance(self.dataset, Subset) else self.dataset
|
366 |
)
|
367 |
+
total_len = len(dataset)
|
368 |
+
ds_weights = [len(ds) / total_len for ds in dataset._data.datasets]
|
369 |
+
weights = sum(
|
370 |
+
ds.song_dataset.get_label_weights() * w
|
371 |
+
for ds, w in zip(dataset._data.datasets, ds_weights)
|
372 |
+
)
|
373 |
+
return weights
|
374 |
|
375 |
|
376 |
def find_mean_std(dataset: Dataset, zscore=1.96, moe=0.02, p=0.5):
|