Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,893 Bytes
0bac694 dc08c30 cff3f6e 0bac694 dc08c30 0bac694 dc08c30 0bac694 dc08c30 0bac694 dc08c30 0bac694 dc08c30 0bac694 dc08c30 0bac694 dc08c30 0bac694 cff3f6e dc08c30 0bac694 dc08c30 cff3f6e dc08c30 0bac694 dc08c30 cff3f6e dc08c30 cff3f6e dc08c30 0bac694 cff3f6e 0bac694 dc08c30 cff3f6e dc08c30 cff3f6e dc08c30 0bac694 cff3f6e dc08c30 cff3f6e dc08c30 cff3f6e dc08c30 cff3f6e 0bac694 dc08c30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
import os
import librosa
import soundfile as sf
import numpy as np
import argparse
import logging
import gc
# Apply the basic logging configuration with the level set to INFO (runs at import time).
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by the functions in this file.
logger = logging.getLogger(__name__)
def stft(wave, nfft, hl):
    """
    Compute per-channel STFTs for the first two channels of a waveform.

    :param wave: indexable waveform; ``wave[0]`` and ``wave[1]`` are used as the two channels
    :param nfft: FFT size, forwarded to ``librosa.stft`` as ``n_fft``
    :param hl: hop length, forwarded to ``librosa.stft`` as ``hop_length``
    :return: the two channel spectrograms stacked along a new leading axis
    """
    # Make both channels contiguous first, then transform them in order.
    channels = [np.ascontiguousarray(wave[idx]) for idx in (0, 1)]
    return np.stack(
        [librosa.stft(samples, n_fft=nfft, hop_length=hl) for samples in channels]
    )
def istft(spec, hl, length):
    """
    Invert the first two channels of a spectrogram back to waveforms.

    :param spec: indexable spectrogram; ``spec[0]`` and ``spec[1]`` are used as the two channels
    :param hl: hop length, forwarded to ``librosa.istft`` as ``hop_length``
    :param length: output length, forwarded to ``librosa.istft`` as ``length``
    :return: the two channel waveforms stacked along a new leading axis
    """
    # Make both channels contiguous first, then invert them in order.
    channels = [np.ascontiguousarray(spec[idx]) for idx in (0, 1)]
    return np.stack(
        [librosa.istft(bins, hop_length=hl, length=length) for bins in channels]
    )
def absmax(a, *, axis):
    """
    Reduce ``a`` along ``axis`` by keeping the element with the largest absolute value.

    The returned values are the original (signed / complex) elements of ``a``,
    not their magnitudes.

    :param a: input NumPy array
    :param axis: axis to reduce over (keyword-only, may be negative)
    :return: array with ``axis`` removed
    """
    rank = len(a.shape)
    remaining = list(a.shape)
    remaining.pop(axis)
    # Open grids enumerate every position along the axes that are kept.
    selector = list(np.ogrid[tuple(slice(0, extent) for extent in remaining)])
    winners = np.abs(a).argmax(axis=axis)
    # Slot the winning positions in at the reduced axis (normalised to be non-negative).
    selector.insert((rank + axis) % rank, winners)
    return a[tuple(selector)]
def absmin(a, *, axis):
    """
    Reduce ``a`` along ``axis`` by keeping the element with the smallest absolute value.

    The returned values are the original (signed / complex) elements of ``a``,
    not their magnitudes.

    :param a: input NumPy array
    :param axis: axis to reduce over (keyword-only, may be negative)
    :return: array with ``axis`` removed
    """
    rank = len(a.shape)
    remaining = list(a.shape)
    remaining.pop(axis)
    # Open grids enumerate every position along the axes that are kept.
    selector = list(np.ogrid[tuple(slice(0, extent) for extent in remaining)])
    smallest = np.abs(a).argmin(axis=axis)
    # Slot the chosen positions in at the reduced axis (normalised to be non-negative).
    selector.insert((rank + axis) % rank, smallest)
    return a[tuple(selector)]
def lambda_max(arr, axis=None, key=None, keepdims=False):
    """
    Return the elements of ``arr`` at which ``key(arr)`` is largest.

    :param arr: input NumPy array
    :param axis: axis to reduce over; ``None`` searches the flattened array
    :param key: callable applied to the whole array to produce the values being
        compared; ``None`` compares the raw values of ``arr``
    :param keepdims: keep the reduced axis with length one (only applied when
        ``axis`` is given)
    :return: the selected elements of ``arr`` (original values, not ``key`` values)
    """
    # key defaults to None in the signature, so it must be usable without one.
    scores = arr if key is None else key(arr)
    idxs = np.argmax(scores, axis)
    if axis is None:
        # Without an axis, argmax yields an index into the flattened array.
        return arr.ravel()[idxs]
    idxs = np.expand_dims(idxs, axis)
    result = np.take_along_axis(arr, idxs, axis)
    if not keepdims:
        result = np.squeeze(result, axis=axis)
    return result
def lambda_min(arr, axis=None, key=None, keepdims=False):
    """
    Return the elements of ``arr`` at which ``key(arr)`` is smallest.

    :param arr: input NumPy array
    :param axis: axis to reduce over; ``None`` searches the flattened array
    :param key: callable applied to the whole array to produce the values being
        compared; ``None`` compares the raw values of ``arr``
    :param keepdims: keep the reduced axis with length one (only applied when
        ``axis`` is given)
    :return: the selected elements of ``arr`` (original values, not ``key`` values)
    """
    # key defaults to None in the signature, so it must be usable without one.
    scores = arr if key is None else key(arr)
    idxs = np.argmin(scores, axis)
    if axis is None:
        # Without an axis, argmin yields an index into the flattened array.
        return arr.ravel()[idxs]
    idxs = np.expand_dims(idxs, axis)
    result = np.take_along_axis(arr, idxs, axis)
    if not keepdims:
        result = np.squeeze(result, axis=axis)
    return result
def average_waveforms(pred_track, weights, algorithm):
    """
    Combine several predictions of the same track into a single waveform.

    The ``*_wave`` algorithms operate directly on the samples; the ``*_fft``
    algorithms operate on STFT spectrograms (n_fft 2048, hop 1024) and convert
    the combined spectrogram back to a waveform of the original length.

    :param pred_track: shape = (num, channels, length)
    :param weights: shape = (num, ); only used by avg_wave and avg_fft
    :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
    :return: averaged waveform in shape (channels, length)
    :raises ValueError: if ``algorithm`` is not one of the supported names, or if a
        weighted algorithm receives a number of weights different from ``num``
    """
    wave_algorithms = ('avg_wave', 'median_wave', 'min_wave', 'max_wave')
    fft_algorithms = ('avg_fft', 'median_fft', 'min_fft', 'max_fft')
    # Reject unknown names up front instead of failing later on an unbound result.
    if algorithm not in wave_algorithms + fft_algorithms:
        raise ValueError(
            f"Unknown ensemble algorithm: {algorithm!r}; "
            f"expected one of {', '.join(wave_algorithms + fft_algorithms)}"
        )

    pred_track = np.asarray(pred_track)  # NumPy 2.0+ compatibility
    num_tracks = pred_track.shape[0]
    final_length = pred_track.shape[-1]
    use_fft = algorithm in fft_algorithms
    weighted = algorithm in ('avg_wave', 'avg_fft')

    # Weights are only consumed by the averaging modes, so only validate them there.
    if weighted and len(weights) != num_tracks:
        raise ValueError(
            f"Number of weights ({len(weights)}) must match number of tracks ({num_tracks})"
        )

    n_fft = 2048
    hop_length = 1024

    mod_track = []
    for i in range(num_tracks):
        track = stft(pred_track[i], nfft=n_fft, hl=hop_length) if use_fft else pred_track[i]
        mod_track.append((track * weights[i]) if weighted else track)
    mod_track = np.asarray(mod_track)  # NumPy 2.0+ compatibility

    reduction = algorithm.split('_')[0]
    if reduction == 'avg':
        # Weighted mean: tracks were already scaled by their weights above.
        result = mod_track.sum(axis=0) / np.sum(weights)
    elif reduction == 'median':
        result = np.median(mod_track, axis=0)
    elif reduction == 'min':
        # Keep, per element, the value with the smallest magnitude.
        result = lambda_min(mod_track, axis=0, key=np.abs)
    else:
        # Keep, per element, the value with the largest magnitude.
        result = lambda_max(mod_track, axis=0, key=np.abs)

    if use_fft:
        result = istft(result, hop_length, final_length)

    gc.collect()
    return result
def ensemble_files(args):
    """
    Load several audio files, ensemble them and write the result to disk.

    :param args: either a list of command-line arguments (``--files``, ``--type``,
        ``--weights``, ``--output``) or an already-parsed object exposing
        ``files``, ``type``, ``weights`` and ``output`` attributes
    :return: the output file path (``args.output``)
    :raises ValueError: if the number of weights differs from the number of files
    :raises FileNotFoundError: if one of the input files does not exist
    :raises RuntimeError: if reading a file fails (including a sample-rate
        mismatch between files) or if ensembling / writing the output fails
    """
    # Only a raw argument list needs parsing; anything else is used as-is.
    if isinstance(args, list):
        parser = argparse.ArgumentParser(description="Ensemble audio files")
        parser.add_argument('--files', nargs='+', required=True, help="Input audio files")
        parser.add_argument(
            '--type',
            required=True,
            choices=['avg_wave', 'median_wave', 'max_wave', 'min_wave',
                     'avg_fft', 'median_fft', 'max_fft', 'min_fft'],
            help="Ensemble type",
        )
        parser.add_argument('--weights', nargs='+', type=float, default=None, help="Weights for each file")
        parser.add_argument('--output', required=True, help="Output file path")
        args = parser.parse_args(args)

    logger.info(f"Ensemble type: {args.type}")
    logger.info(f"Number of input files: {len(args.files)}")

    # Missing (or empty) weights mean every file contributes equally.
    weights = args.weights if args.weights else [1.0] * len(args.files)
    if len(weights) != len(args.files):
        logger.error("Number of weights must match number of audio files")
        raise ValueError("Number of weights must match number of audio files")

    logger.info(f"Weights: {weights}")
    logger.info(f"Output file: {args.output}")

    data = []
    sr = None
    for f in args.files:
        if not os.path.isfile(f):
            logger.error(f"Cannot find file: {f}")
            raise FileNotFoundError(f"Cannot find file: {f}")

        logger.info(f"Reading file: {f}")
        try:
            # sr=None keeps the file's own sample rate; mono=False keeps its channels.
            wav, curr_sr = librosa.load(f, sr=None, mono=False)
            if sr is None:
                sr = curr_sr
            elif sr != curr_sr:
                logger.error("All audio files must have the same sample rate")
                raise ValueError("All audio files must have the same sample rate")
            logger.info(f"Waveform shape: {wav.shape} sample rate: {sr}")
            data.append(wav)
        except Exception as e:
            logger.error(f"Error reading audio file {f}: {str(e)}")
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Error reading audio file {f}: {str(e)}") from e

    try:
        data = np.asarray(data)  # NumPy 2.0+ compatibility
        res = average_waveforms(data, weights, args.type)
        logger.info(f"Result shape: {res.shape}")

        # A bare file name has an empty dirname, and os.makedirs('') raises,
        # so only create the directory when the path actually contains one.
        output_dir = os.path.dirname(args.output)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # soundfile expects (frames, channels), hence the transpose.
        sf.write(args.output, res.T, sr, 'FLOAT')
        logger.info(f"Output written to: {args.output}")
        return args.output
    except Exception as e:
        logger.error(f"Error during ensemble processing: {str(e)}")
        raise RuntimeError(f"Error during ensemble processing: {str(e)}") from e
    finally:
        gc.collect()
if __name__ == "__main__":
    # Imported locally so the script entry point does not depend on a
    # module-level `import sys` being present.
    import sys

    # Forward the command-line arguments (without the program name).
    ensemble_files(sys.argv[1:])