File size: 6,893 Bytes
0bac694
 
 
 
 
 
 
 
dc08c30
cff3f6e
0bac694
dc08c30
 
 
0bac694
dc08c30
 
0bac694
 
dc08c30
0bac694
 
 
dc08c30
 
0bac694
 
dc08c30
0bac694
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc08c30
0bac694
 
 
 
 
 
 
 
 
 
dc08c30
0bac694
 
 
cff3f6e
 
dc08c30
0bac694
dc08c30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cff3f6e
 
dc08c30
0bac694
 
dc08c30
 
 
 
 
cff3f6e
dc08c30
cff3f6e
dc08c30
 
 
 
 
 
 
 
0bac694
 
cff3f6e
0bac694
 
dc08c30
 
 
cff3f6e
 
 
 
 
dc08c30
 
 
cff3f6e
 
 
 
dc08c30
 
0bac694
cff3f6e
dc08c30
cff3f6e
dc08c30
 
 
 
 
cff3f6e
dc08c30
 
cff3f6e
 
0bac694
 
dc08c30
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'

import os
import librosa
import soundfile as sf
import numpy as np
import argparse
import logging
import gc

# Configure the root logger at import time with an INFO threshold.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)

def stft(wave, nfft, hl):
    """
    Compute the short-time Fourier transform of the first two channels of a waveform.

    :param wave: indexable waveform; only ``wave[0]`` and ``wave[1]`` are read
        (presumably a (channels, length) stereo array -- confirm against callers)
    :param nfft: FFT size forwarded to ``librosa.stft`` as ``n_fft``
    :param hl: hop length forwarded to ``librosa.stft`` as ``hop_length``
    :return: the two per-channel spectrograms stacked along a new leading axis
    """
    # librosa is handed a contiguous copy/view of each channel slice.
    channel_specs = [
        librosa.stft(np.ascontiguousarray(wave[ch]), n_fft=nfft, hop_length=hl)
        for ch in (0, 1)
    ]
    return np.stack(channel_specs)

def istft(spec, hl, length):
    """
    Invert the spectrograms of the first two channels back to waveforms.

    :param spec: indexable spectrogram; only ``spec[0]`` and ``spec[1]`` are read
    :param hl: hop length forwarded to ``librosa.istft`` as ``hop_length``
    :param length: target number of samples, forwarded to ``librosa.istft`` as ``length``
    :return: the two reconstructed channels stacked along a new leading axis
    """
    # librosa is handed a contiguous copy/view of each channel slice.
    channel_waves = [
        librosa.istft(np.ascontiguousarray(spec[ch]), hop_length=hl, length=length)
        for ch in (0, 1)
    ]
    return np.stack(channel_waves)

def absmax(a, *, axis):
    """
    Reduce ``a`` along ``axis`` by keeping, at every position, the element whose
    absolute value is largest (the original signed/complex value is returned).

    :param a: NumPy array to reduce
    :param axis: axis to reduce over (keyword-only; negative values are allowed)
    :return: array with ``axis`` removed
    """
    # Shape of the result: every dimension except the reduced one.
    kept_dims = list(a.shape)
    del kept_dims[axis]

    # Open (broadcastable) index grids for the kept dimensions.
    grid = list(np.ogrid[tuple(slice(0, size) for size in kept_dims)])

    # Along the reduced axis, index with the position of the largest magnitude.
    winners = np.abs(a).argmax(axis=axis)
    position = axis % a.ndim  # normalise a negative axis to its positive slot
    selector = grid[:position] + [winners] + grid[position:]
    return a[tuple(selector)]

def absmin(a, *, axis):
    """
    Reduce ``a`` along ``axis`` by keeping, at every position, the element whose
    absolute value is smallest (the original signed/complex value is returned).

    :param a: NumPy array to reduce
    :param axis: axis to reduce over (keyword-only; negative values are allowed)
    :return: array with ``axis`` removed
    """
    # Shape of the result: every dimension except the reduced one.
    kept_dims = list(a.shape)
    del kept_dims[axis]

    # Open (broadcastable) index grids for the kept dimensions.
    grid = list(np.ogrid[tuple(slice(0, size) for size in kept_dims)])

    # Along the reduced axis, index with the position of the smallest magnitude.
    winners = np.abs(a).argmin(axis=axis)
    position = axis % a.ndim  # normalise a negative axis to its positive slot
    selector = grid[:position] + [winners] + grid[position:]
    return a[tuple(selector)]

def lambda_max(arr, axis=None, key=None, keepdims=False):
    """
    Return the element(s) of ``arr`` that maximise ``key(arr)``.

    :param arr: array-like input; the returned values come from this array
    :param axis: axis to reduce over, or None to search the flattened array
    :param key: callable applied to the whole array to produce the values being
        compared (e.g. ``np.abs``); None compares the elements themselves
    :param keepdims: if True and ``axis`` is given, keep the reduced axis with
        size 1; has no effect when ``axis`` is None
    :return: the selected element when ``axis`` is None, otherwise an array with
        ``axis`` removed (or kept with size 1 when ``keepdims`` is True)
    """
    arr = np.asarray(arr)
    # The default key=None used to be called directly and raised TypeError;
    # treat it as "compare the raw values".
    scores = arr if key is None else key(arr)
    idxs = np.argmax(scores, axis)

    if axis is None:
        # argmax returned a flat C-order index.
        return arr.ravel()[idxs]

    idxs = np.expand_dims(idxs, axis)
    result = np.take_along_axis(arr, idxs, axis)
    if not keepdims:
        result = np.squeeze(result, axis=axis)
    return result

def lambda_min(arr, axis=None, key=None, keepdims=False):
    """
    Return the element(s) of ``arr`` that minimise ``key(arr)``.

    :param arr: array-like input; the returned values come from this array
    :param axis: axis to reduce over, or None to search the flattened array
    :param key: callable applied to the whole array to produce the values being
        compared (e.g. ``np.abs``); None compares the elements themselves
    :param keepdims: if True and ``axis`` is given, keep the reduced axis with
        size 1; has no effect when ``axis`` is None
    :return: the selected element when ``axis`` is None, otherwise an array with
        ``axis`` removed (or kept with size 1 when ``keepdims`` is True)
    """
    arr = np.asarray(arr)
    # The default key=None used to be called directly and raised TypeError;
    # treat it as "compare the raw values".
    scores = arr if key is None else key(arr)
    idxs = np.argmin(scores, axis)

    if axis is None:
        # argmin returned a flat C-order index.
        return arr.ravel()[idxs]

    idxs = np.expand_dims(idxs, axis)
    result = np.take_along_axis(arr, idxs, axis)
    if not keepdims:
        result = np.squeeze(result, axis=axis)
    return result

def average_waveforms(pred_track, weights, algorithm):
    """
    Combine several versions of the same track into a single waveform.

    The ``*_wave`` algorithms combine the samples directly. The ``*_fft``
    algorithms transform every track with ``stft`` (nfft=2048, hop 1024),
    combine the spectrograms and transform the result back with ``istft``.
    ``min``/``max`` keep, per position, the value with the smallest/largest
    absolute value across tracks.

    :param pred_track: shape = (num, channels, length)
    :param weights: shape = (num, ); only read by avg_wave and avg_fft
    :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
    :return: averaged waveform in shape (channels, length)
    :raises ValueError: if ``algorithm`` is unknown, ``pred_track`` holds no
        tracks, or (for the avg_* algorithms) the number of weights differs
        from the number of tracks
    """
    wave_algorithms = ('avg_wave', 'median_wave', 'min_wave', 'max_wave')
    fft_algorithms = ('avg_fft', 'median_fft', 'min_fft', 'max_fft')
    # Fail early: an unknown name used to fall through every branch and end in
    # an UnboundLocalError on the return statement.
    if algorithm not in wave_algorithms + fft_algorithms:
        raise ValueError(
            f"Unknown ensemble algorithm: {algorithm!r}. "
            f"Expected one of: {', '.join(wave_algorithms + fft_algorithms)}"
        )

    pred_track = np.asarray(pred_track)  # NumPy 2.0+ compatibility
    num_tracks = pred_track.shape[0]
    if num_tracks == 0:
        raise ValueError("pred_track must contain at least one track")
    final_length = pred_track.shape[-1]

    weighted = algorithm in ('avg_wave', 'avg_fft')
    # A mismatch would otherwise raise IndexError or normalise by the wrong sum.
    if weighted and len(weights) != num_tracks:
        raise ValueError(
            f"Number of weights ({len(weights)}) must match number of tracks ({num_tracks})"
        )

    in_fft_domain = algorithm in fft_algorithms
    mod_track = []
    for i in range(num_tracks):
        track = pred_track[i]
        if in_fft_domain:
            track = stft(track, nfft=2048, hl=1024)
        # Weights are applied per track with a plain scalar multiply so the
        # result dtype follows the track's dtype, as before.
        mod_track.append(track * weights[i] if weighted else track)
    mod_track = np.asarray(mod_track)  # NumPy 2.0+ compatibility

    if weighted:
        result = mod_track.sum(axis=0) / np.sum(weights)
    elif algorithm.startswith('median'):
        result = np.median(mod_track, axis=0)
    elif algorithm.startswith('min'):
        result = lambda_min(mod_track, axis=0, key=np.abs)
    else:
        # max_wave / max_fft: same selection rule for both domains.
        result = lambda_max(mod_track, axis=0, key=np.abs)

    if in_fft_domain:
        result = istft(result, 1024, final_length)

    gc.collect()
    return result

def ensemble_files(args):
    """
    Read several audio files, combine them with ``average_waveforms`` and write
    the result to disk.

    :param args: either a list/tuple of command-line tokens (options --files,
        --type, --weights, --output) or an already parsed object exposing the
        attributes ``files``, ``type``, ``weights`` and ``output``
    :return: the output path (``args.output``)
    :raises ValueError: if the number of weights differs from the number of files
    :raises FileNotFoundError: if one of the input paths is not an existing file
    :raises RuntimeError: if a file cannot be read, the sample rates or waveform
        shapes differ, or combining/writing the result fails
    """
    if isinstance(args, (list, tuple)):
        parser = argparse.ArgumentParser(description="Ensemble audio files")
        parser.add_argument('--files', nargs='+', required=True, help="Input audio files")
        parser.add_argument('--type', required=True, choices=['avg_wave', 'median_wave', 'max_wave', 'min_wave', 'avg_fft', 'median_fft', 'max_fft', 'min_fft'], help="Ensemble type")
        parser.add_argument('--weights', nargs='+', type=float, default=None, help="Weights for each file")
        parser.add_argument('--output', required=True, help="Output file path")
        args = parser.parse_args(list(args))

    logger.info(f"Ensemble type: {args.type}")
    logger.info(f"Number of input files: {len(args.files)}")
    # Missing weights mean every file counts equally.
    weights = args.weights if args.weights else [1.0] * len(args.files)
    if len(weights) != len(args.files):
        logger.error("Number of weights must match number of audio files")
        raise ValueError("Number of weights must match number of audio files")
    logger.info(f"Weights: {weights}")
    logger.info(f"Output file: {args.output}")

    data = []
    sr = None
    for f in args.files:
        if not os.path.isfile(f):
            logger.error(f"Cannot find file: {f}")
            raise FileNotFoundError(f"Cannot find file: {f}")
        logger.info(f"Reading file: {f}")
        try:
            # sr=None keeps the file's native sample rate; mono=False keeps channels.
            wav, curr_sr = librosa.load(f, sr=None, mono=False)
            if sr is None:
                sr = curr_sr
            elif sr != curr_sr:
                # NOTE: this ValueError is caught by the handler below and reaches
                # the caller as RuntimeError; kept that way for compatibility.
                logger.error("All audio files must have the same sample rate")
                raise ValueError("All audio files must have the same sample rate")
            logger.info(f"Waveform shape: {wav.shape} sample rate: {sr}")
            data.append(wav)
        except Exception as e:
            logger.error(f"Error reading audio file {f}: {str(e)}")
            raise RuntimeError(f"Error reading audio file {f}: {str(e)}") from e

    try:
        if not data:
            raise ValueError("No audio files were provided")
        # Stacking needs identical shapes; report a mismatch explicitly instead
        # of letting NumPy fail on a ragged array.
        shapes = {wav.shape for wav in data}
        if len(shapes) > 1:
            raise ValueError(
                f"All audio files must have the same shape, got: {sorted(shapes)}"
            )
        data = np.asarray(data)  # NumPy 2.0+ compatibility
        res = average_waveforms(data, weights, args.type)
        logger.info(f"Result shape: {res.shape}")
        # os.makedirs('') raises, so only create a directory when the output
        # path actually has a directory component.
        out_dir = os.path.dirname(args.output)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # The result is transposed before being handed to soundfile.
        sf.write(args.output, res.T, sr, 'FLOAT')
        logger.info(f"Output written to: {args.output}")
        return args.output
    except Exception as e:
        logger.error(f"Error during ensemble processing: {str(e)}")
        raise RuntimeError(f"Error during ensemble processing: {str(e)}") from e
    finally:
        gc.collect()

if __name__ == "__main__":
    # sys is not among the module-level imports, so import it here; without it
    # the script entry point fails with NameError on sys.argv.
    import sys

    ensemble_files(sys.argv[1:])