waidhoferj committed on
Commit
81caeb5
·
1 Parent(s): ba35f85

added weighted datasets

Browse files
Files changed (1) hide show
  1. preprocessing/dataset.py +9 -5
preprocessing/dataset.py CHANGED
@@ -78,7 +78,6 @@ class SongDataset(Dataset):
78
  return waveform, dance_labels
79
  else:
80
  # WARNING: Could cause train/test split leak
81
- print("Invalid output, trying next index...")
82
  return self[idx - 1]
83
 
84
  def _idx2audio_idx(self, idx: int) -> int:
@@ -112,8 +111,8 @@ class SongDataset(Dataset):
112
  is_finite = not torch.any(torch.isinf(x))
113
  is_numerical = not torch.any(torch.isnan(x))
114
  has_data = torch.any(x != 0.0)
115
- is_binary = len(torch.unique(y)) < 3
116
- return all((is_finite, is_numerical, has_data, is_binary))
117
 
118
  def _waveform_from_index(self, idx: int) -> torch.Tensor:
119
  audio_index, frame_index = self._get_audio_loc_from_idx(idx)
@@ -365,8 +364,13 @@ class DanceDataModule(pl.LightningDataModule):
365
  dataset = (
366
  self.dataset.dataset if isinstance(self.dataset, Subset) else self.dataset
367
  )
368
- weights = [ds.song_dataset.get_label_weights() for ds in dataset._data.datasets]
369
- return torch.mean(torch.stack(weights), dim=0) # TODO: Make this weighted
 
 
 
 
 
370
 
371
 
372
  def find_mean_std(dataset: Dataset, zscore=1.96, moe=0.02, p=0.5):
 
78
  return waveform, dance_labels
79
  else:
80
  # WARNING: Could cause train/test split leak
 
81
  return self[idx - 1]
82
 
83
  def _idx2audio_idx(self, idx: int) -> int:
 
111
  is_finite = not torch.any(torch.isinf(x))
112
  is_numerical = not torch.any(torch.isnan(x))
113
  has_data = torch.any(x != 0.0)
114
+ is_probability = torch.all(y >= -0.0001) and torch.all(y <= 1.0001)
115
+ return all((is_finite, is_numerical, has_data, is_probability))
116
 
117
  def _waveform_from_index(self, idx: int) -> torch.Tensor:
118
  audio_index, frame_index = self._get_audio_loc_from_idx(idx)
 
364
  dataset = (
365
  self.dataset.dataset if isinstance(self.dataset, Subset) else self.dataset
366
  )
367
+ total_len = len(dataset)
368
+ ds_weights = [len(ds) / total_len for ds in dataset._data.datasets]
369
+ weights = sum(
370
+ ds.song_dataset.get_label_weights() * w
371
+ for ds, w in zip(dataset._data.datasets, ds_weights)
372
+ )
373
+ return weights
374
 
375
 
376
  def find_mean_std(dataset: Dataset, zscore=1.96, moe=0.02, p=0.5):