ASesYusuf1 commited on
Commit
dc08c30
·
verified ·
1 Parent(s): efe5936

Update ensemble.py

Browse files
Files changed (1) hide show
  1. ensemble.py +65 -74
ensemble.py CHANGED
@@ -6,23 +6,26 @@ import librosa
6
  import soundfile as sf
7
  import numpy as np
8
  import argparse
9
- import uuid
10
  import gc
11
 
 
 
 
12
  def stft(wave, nfft, hl):
13
- wave_left = np.asfortranarray(wave[0])
14
- wave_right = np.asfortranarray(wave[1])
15
  spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
16
  spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
17
- spec = np.asfortranarray([spec_left, spec_right])
18
  return spec
19
 
20
  def istft(spec, hl, length):
21
- spec_left = np.asfortranarray(spec[0])
22
- spec_right = np.asfortranarray(spec[1])
23
  wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
24
  wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
25
- wave = np.asfortranarray([wave_left, wave_right])
26
  return wave
27
 
28
  def absmax(a, *, axis):
@@ -72,7 +75,7 @@ def average_waveforms(pred_track, weights, algorithm):
72
  :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
73
  :return: averaged waveform in shape (channels, length)
74
  """
75
- pred_track = np.array(pred_track, copy=False)
76
  final_length = pred_track.shape[-1]
77
 
78
  mod_track = []
@@ -83,103 +86,91 @@ def average_waveforms(pred_track, weights, algorithm):
83
  mod_track.append(pred_track[i])
84
  elif algorithm in ['avg_fft', 'min_fft', 'max_fft', 'median_fft']:
85
  spec = stft(pred_track[i], nfft=2048, hl=1024)
86
- if algorithm in ['avg_fft']:
87
  mod_track.append(spec * weights[i])
88
  else:
89
  mod_track.append(spec)
90
  del spec
91
  gc.collect()
92
- pred_track = np.array(mod_track, copy=False)
93
 
94
- if algorithm in ['avg_wave']:
95
- pred_track = pred_track.sum(axis=0)
96
- pred_track /= np.array(weights).sum()
97
- elif algorithm in ['median_wave']:
98
- pred_track = np.median(pred_track, axis=0)
99
- elif algorithm in ['min_wave']:
100
- pred_track = lambda_min(pred_track, axis=0, key=np.abs)
101
- elif algorithm in ['max_wave']:
102
- pred_track = lambda_max(pred_track, axis=0, key=np.abs)
103
- elif algorithm in ['avg_fft']:
104
- pred_track = pred_track.sum(axis=0)
105
- pred_track /= np.array(weights).sum()
106
- pred_track = istft(pred_track, 1024, final_length)
107
- elif algorithm in ['min_fft']:
108
- pred_track = lambda_min(pred_track, axis=0, key=np.abs)
109
- pred_track = istft(pred_track, 1024, final_length)
110
- elif algorithm in ['max_fft']:
111
- pred_track = absmax(pred_track, axis=0)
112
- pred_track = istft(pred_track, 1024, final_length)
113
- elif algorithm in ['median_fft']:
114
- pred_track = np.median(pred_track, axis=0)
115
- pred_track = istft(pred_track, 1024, final_length)
116
 
117
  gc.collect()
118
- return pred_track
119
 
120
  def ensemble_files(args):
121
- parser = argparse.ArgumentParser()
122
- parser.add_argument("--files", type=str, required=True, nargs='+', help="Path to all audio-files to ensemble")
123
- parser.add_argument("--type", type=str, default='avg_wave', help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft")
124
- parser.add_argument("--weights", type=float, nargs='+', help="Weights to create ensemble. Number of weights must be equal to number of files")
125
- parser.add_argument("--output", default="res.wav", type=str, help="Path to wav file where ensemble result will be stored")
126
 
127
- try:
128
- args = parser.parse_args(args) if isinstance(args, list) else parser.parse_args()
129
- except SystemExit:
130
- print("Error: Invalid command-line arguments. Check --files, --type, --weights, and --output.")
131
- return None
132
-
133
- print('Ensemble type: {}'.format(args.type))
134
- print('Number of input files: {}'.format(len(args.files)))
135
- if args.weights is not None:
136
- weights = args.weights
137
- if len(weights) != len(args.files):
138
- print('Error: Number of weights must match number of audio files.')
139
- return None
140
- else:
141
- weights = np.ones(len(args.files))
142
- print('Weights: {}'.format(weights))
143
 
144
- # Validate output name
145
- if not args.output.endswith('.wav'):
146
- args.output += '.wav'
147
- output_path = os.path.join('/tmp', str(uuid.uuid4()) + '_' + args.output)
148
- print('Output file: {}'.format(output_path))
 
 
 
149
 
150
  data = []
151
  sr = None
152
  for f in args.files:
153
  if not os.path.isfile(f):
154
- print('Error. Can\'t find file: {}. Check paths.'.format(f))
155
- return None
156
- print('Reading file: {}'.format(f))
157
  try:
158
  wav, curr_sr = librosa.load(f, sr=None, mono=False)
159
  if sr is None:
160
  sr = curr_sr
161
  elif sr != curr_sr:
162
- print('Error: All audio files must have the same sample rate.')
163
- return None
164
- print("Waveform shape: {} sample rate: {}".format(wav.shape, sr))
165
  data.append(wav)
166
  del wav
167
  gc.collect()
168
  except Exception as e:
169
- print(f'Error reading audio file {f}: {str(e)}')
170
- return None
171
 
172
  try:
173
- data = np.array(data, copy=False)
174
  res = average_waveforms(data, weights, args.type)
175
- print('Result shape: {}'.format(res.shape))
176
- sf.write(output_path, res.T, sr, 'FLOAT')
177
- return output_path
 
 
178
  except Exception as e:
179
- print(f'Error during ensemble processing: {str(e)}')
180
- return None
181
  finally:
182
  gc.collect()
183
 
184
  if __name__ == "__main__":
185
- ensemble_files(None)
 
6
  import soundfile as sf
7
  import numpy as np
8
  import argparse
9
+ import logging
10
  import gc
11
 
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
  def stft(wave, nfft, hl):
16
+ wave_left = np.ascontiguousarray(wave[0])
17
+ wave_right = np.ascontiguousarray(wave[1])
18
  spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
19
  spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
20
+ spec = np.stack([spec_left, spec_right])
21
  return spec
22
 
23
  def istft(spec, hl, length):
24
+ spec_left = np.ascontiguousarray(spec[0])
25
+ spec_right = np.ascontiguousarray(spec[1])
26
  wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
27
  wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
28
+ wave = np.stack([wave_left, wave_right])
29
  return wave
30
 
31
  def absmax(a, *, axis):
 
75
  :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
76
  :return: averaged waveform in shape (channels, length)
77
  """
78
+ pred_track = np.asarray(pred_track) # NumPy 2.0+ compatibility
79
  final_length = pred_track.shape[-1]
80
 
81
  mod_track = []
 
86
  mod_track.append(pred_track[i])
87
  elif algorithm in ['avg_fft', 'min_fft', 'max_fft', 'median_fft']:
88
  spec = stft(pred_track[i], nfft=2048, hl=1024)
89
+ if algorithm == 'avg_fft':
90
  mod_track.append(spec * weights[i])
91
  else:
92
  mod_track.append(spec)
93
  del spec
94
  gc.collect()
95
+ mod_track = np.asarray(mod_track) # NumPy 2.0+ compatibility
96
 
97
+ if algorithm == 'avg_wave':
98
+ result = mod_track.sum(axis=0) / np.sum(weights)
99
+ elif algorithm == 'median_wave':
100
+ result = np.median(mod_track, axis=0)
101
+ elif algorithm == 'min_wave':
102
+ result = lambda_min(mod_track, axis=0, key=np.abs)
103
+ elif algorithm == 'max_wave':
104
+ result = lambda_max(mod_track, axis=0, key=np.abs)
105
+ elif algorithm == 'avg_fft':
106
+ result = mod_track.sum(axis=0) / np.sum(weights)
107
+ result = istft(result, 1024, final_length)
108
+ elif algorithm == 'min_fft':
109
+ result = lambda_min(mod_track, axis=0, key=np.abs)
110
+ result = istft(result, 1024, final_length)
111
+ elif algorithm == 'max_fft':
112
+ result = absmax(mod_track, axis=0)
113
+ result = istft(result, 1024, final_length)
114
+ elif algorithm == 'median_fft':
115
+ result = np.median(mod_track, axis=0)
116
+ result = istft(result, 1024, final_length)
 
 
117
 
118
  gc.collect()
119
+ return result
120
 
121
  def ensemble_files(args):
122
+ parser = argparse.ArgumentParser(description="Ensemble audio files")
123
+ parser.add_argument('--files', nargs='+', required=True, help="Input audio files")
124
+ parser.add_argument('--type', required=True, choices=['avg_wave', 'median_wave', 'max_wave', 'min_wave', 'avg_fft', 'median_fft', 'max_fft', 'min_fft'], help="Ensemble type")
125
+ parser.add_argument('--weights', nargs='+', type=float, default=None, help="Weights for each file")
126
+ parser.add_argument('--output', required=True, help="Output file path")
127
 
128
+ args = parser.parse_args(args) if isinstance(args, list) else args
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
+ logger.info(f"Ensemble type: {args.type}")
131
+ logger.info(f"Number of input files: {len(args.files)}")
132
+ weights = args.weights if args.weights else [1.0] * len(args.files)
133
+ if len(weights) != len(args.files):
134
+ logger.error("Number of weights must match number of audio files")
135
+ raise ValueError("Number of weights must match number of audio files")
136
+ logger.info(f"Weights: {weights}")
137
+ logger.info(f"Output file: {args.output}")
138
 
139
  data = []
140
  sr = None
141
  for f in args.files:
142
  if not os.path.isfile(f):
143
+ logger.error(f"Cannot find file: {f}")
144
+ raise FileNotFoundError(f"Cannot find file: {f}")
145
+ logger.info(f"Reading file: {f}")
146
  try:
147
  wav, curr_sr = librosa.load(f, sr=None, mono=False)
148
  if sr is None:
149
  sr = curr_sr
150
  elif sr != curr_sr:
151
+ logger.error("All audio files must have the same sample rate")
152
+ raise ValueError("All audio files must have the same sample rate")
153
+ logger.info(f"Waveform shape: {wav.shape} sample rate: {sr}")
154
  data.append(wav)
155
  del wav
156
  gc.collect()
157
  except Exception as e:
158
+ logger.error(f"Error reading audio file {f}: {str(e)}")
159
+ raise RuntimeError(f"Error reading audio file {f}: {str(e)}")
160
 
161
  try:
162
+ data = np.asarray(data) # NumPy 2.0+ compatibility
163
  res = average_waveforms(data, weights, args.type)
164
+ logger.info(f"Result shape: {res.shape}")
165
+ os.makedirs(os.path.dirname(args.output), exist_ok=True)
166
+ sf.write(args.output, res.T, sr, 'FLOAT')
167
+ logger.info(f"Output written to: {args.output}")
168
+ return args.output
169
  except Exception as e:
170
+ logger.error(f"Error during ensemble processing: {str(e)}")
171
+ raise RuntimeError(f"Error during ensemble processing: {str(e)}")
172
  finally:
173
  gc.collect()
174
 
175
  if __name__ == "__main__":
176
+ ensemble_files(sys.argv[1:])