# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' # forked by jarredou

import os
import sys
import torch
import torchaudio
import soundfile as sf
import argparse
from pathlib import Path


def stft(wave, n_fft, hop_length):
    """
    Compute the STFT of a multi-channel waveform using PyTorch
    :param wave: tensor of shape (channels, length)
    :param n_fft: FFT size
    :param hop_length: hop length
    :return: complex spectrogram of shape (channels, freq_bins, time_frames)
    """
    # torch.stft handles multi-channel input natively (channels act as a batch dim)
    spec = torch.stft(
        wave,
        n_fft=n_fft,
        hop_length=hop_length,
        window=torch.hann_window(n_fft, device=wave.device),
        return_complex=True
    )
    return spec


def istft(spec, hop_length, length):
    """
    Compute the inverse STFT of a multi-channel spectrogram using PyTorch
    :param spec: complex spectrogram of shape (channels, freq_bins, time_frames)
    :param hop_length: hop length
    :param length: target length
    :return: waveform of shape (channels, length)
    """
    # torch.istft accepts the multi-channel spectrogram and inverts it per channel
    wave = torch.istft(
        spec,
        n_fft=(spec.shape[1] - 1) * 2,
        hop_length=hop_length,
        window=torch.hann_window((spec.shape[1] - 1) * 2, device=spec.device),
        length=length
    )
    return wave
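
# A minimal round-trip sanity check of the two helpers above (an illustrative
# sketch, assuming the script's default n_fft=2048 and hop_length=1024):
#
#   wave = torch.randn(2, 44100)               # 1 s of stereo noise
#   spec = stft(wave, 2048, 1024)              # -> (2, 1025, 44), complex
#   back = istft(spec, 1024, wave.shape[-1])   # -> (2, 44100)
#   torch.allclose(wave, back, atol=1e-4)      # -> True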


def absmax(a, *, dim):
    """Find values with maximum absolute value along dimension"""
    abs_a = torch.abs(a)
    indices = torch.argmax(abs_a, dim=dim, keepdim=True)
    return torch.gather(a, dim, indices).squeeze(dim)


def absmin(a, *, dim):
    """Find values with minimum absolute value along dimension"""
    abs_a = torch.abs(a)
    indices = torch.argmin(abs_a, dim=dim, keepdim=True)
    return torch.gather(a, dim, indices).squeeze(dim)
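
# For example (selection is driven by |x|, but the signed values are kept):
#
#   a = torch.tensor([[-3., 1.], [2., -4.]])
#   absmax(a, dim=0)   # -> tensor([-3., -4.])
#   absmin(a, dim=0)   # -> tensor([ 2.,  1.])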


def lambda_max(arr, dim=None, key=None):
    """Find elements with maximum key value along dimension"""
    if key is None:
        key = lambda x: x
    key_values = key(arr)
    indices = torch.argmax(key_values, dim=dim, keepdim=True)
    return torch.gather(arr, dim, indices).squeeze(dim)


def lambda_min(arr, dim=None, key=None):
    """Find elements with minimum key value along dimension"""
    if key is None:
        key = lambda x: x
    key_values = key(arr)
    indices = torch.argmin(key_values, dim=dim, keepdim=True)
    return torch.gather(arr, dim, indices).squeeze(dim)


def match_tensor_shapes(tensor_1, tensor_2):
    """Match the time dimension of two tensors by padding or trimming"""
    target_length = tensor_2.shape[-1]
    current_length = tensor_1.shape[-1]
    
    if current_length > target_length:
        tensor_1 = tensor_1[..., :target_length]
    elif current_length < target_length:
        padding = target_length - current_length
        tensor_1 = torch.nn.functional.pad(tensor_1, (0, padding), 'constant', 0)
    
    return tensor_1
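
# For example, match_tensor_shapes(torch.ones(2, 100), torch.zeros(2, 80))
# trims to shape (2, 80), while the reverse call zero-pads the tail of the
# shorter tensor out to shape (2, 100).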


def average_waveforms(pred_track, weights, algorithm, n_fft, hop_length):
    """
    :param pred_track: list of tensors of shape (channels, length)
    :param weights: tensor of shape (num,)
    :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
    :param n_fft: FFT size for STFT operations
    :param hop_length: hop length for STFT operations
    :return: averaged waveform in shape (channels, length)
    """
    pred_track = torch.stack(pred_track)
    final_length = pred_track.shape[-1]
    weights = torch.tensor(weights, dtype=pred_track.dtype, device=pred_track.device)

    if algorithm in ['avg_wave', 'median_wave', 'min_wave', 'max_wave']:
        # Waveform domain operations
        if algorithm == 'avg_wave':
            # Weighted average
            weighted_tracks = pred_track * weights.view(-1, 1, 1)
            result = torch.sum(weighted_tracks, dim=0) / torch.sum(weights)
        elif algorithm == 'median_wave':
            result = torch.median(pred_track, dim=0)[0]
        elif algorithm == 'min_wave':
            result = lambda_min(pred_track, dim=0, key=torch.abs)
        elif algorithm == 'max_wave':
            result = lambda_max(pred_track, dim=0, key=torch.abs)
    
    elif algorithm in ['avg_fft', 'median_fft', 'min_fft', 'max_fft']:
        # Frequency domain operations
        # Convert all tracks to spectrograms
        spec_tracks = []
        for i in range(pred_track.shape[0]):
            spec = stft(pred_track[i], n_fft, hop_length)
            spec_tracks.append(spec)
        
        spec_tracks = torch.stack(spec_tracks)
        
        if algorithm == 'avg_fft':
            # Weighted average in frequency domain
            weighted_specs = spec_tracks * weights.view(-1, 1, 1, 1)
            avg_spec = torch.sum(weighted_specs, dim=0) / torch.sum(weights)
            result = istft(avg_spec, hop_length, final_length)
        elif algorithm == 'median_fft':
            # Median in frequency domain. torch.median does not support
            # complex tensors, so take the median of magnitude and phase
            # separately and recombine into a complex spectrogram.
            mag = torch.median(torch.abs(spec_tracks), dim=0)[0]
            phase = torch.median(torch.angle(spec_tracks), dim=0)[0]
            median_spec = torch.polar(mag, phase)
            result = istft(median_spec, hop_length, final_length)
        elif algorithm == 'min_fft':
            min_spec = lambda_min(spec_tracks, dim=0, key=torch.abs)
            result = istft(min_spec, hop_length, final_length)
        elif algorithm == 'max_fft':
            max_spec = absmax(spec_tracks, dim=0)
            result = istft(max_spec, hop_length, final_length)
    
    return result
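
# Minimal usage sketch (hypothetical tensors, default STFT settings):
#
#   tracks = [torch.randn(2, 44100) for _ in range(3)]
#   mix = average_waveforms(tracks, [1.0, 1.0, 2.0], 'avg_wave', 2048, 1024)
#   mix.shape   # -> torch.Size([2, 44100])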


def save_audio(waveform, sample_rate, output_path):
    """Save audio with support for different formats and bit depths using soundfile"""
    output_path = Path(output_path)
    
    # Ensure waveform is in the right format (channels, samples)
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    
    # Convert to CPU numpy array if on GPU
    if waveform.is_cuda:
        waveform = waveform.cpu()
    
    # Convert to numpy and transpose to (samples, channels) for soundfile
    audio_np = waveform.numpy().T
    
    # Handle different output formats
    if output_path.suffix.lower() == '.flac':
        # FLAC output is written as 24-bit integer PCM, so samples above
        # full scale (|x| > 1.0) would clip. Normalize only in that case.
        max_val = torch.max(torch.abs(waveform))
        if max_val > 1.0:
            print(f"Clipping detected (peak {max_val:.4f}). Normalizing.")
            audio_np = audio_np / max_val.item()
        
        # Save as FLAC with 24-bit depth
        sf.write(str(output_path), audio_np, sample_rate, subtype='PCM_24')
    else:
        # Default to float32 for WAV and other formats
        sf.write(str(output_path), audio_np, sample_rate, subtype='FLOAT')


def ensemble_files(args):
    parser = argparse.ArgumentParser(description="Audio ensemble tool using PyTorch")
    parser.add_argument("--files", type=str, required=True, nargs='+', 
                       help="Path to all audio-files to ensemble")
    parser.add_argument("--type", type=str, default='avg_wave', 
                       help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft")
    parser.add_argument("--weights", type=float, nargs='+', 
                       help="Weights to create ensemble. Number of weights must be equal to number of files")
    parser.add_argument("--output", default="res.wav", type=str, 
                       help="Path to output file (supports .wav, .flac)")
    parser.add_argument("--n_fft", type=int, default=2048, 
                       help="FFT size for STFT operations (default: 2048)")
    parser.add_argument("--hop_length", type=int, default=1024, 
                       help="Hop length for STFT operations (default: 1024)")
    parser.add_argument("--device", type=str, default="auto", 
                       help="Device to use: 'cpu', 'cuda', or 'auto' (default: auto)")
    parser.add_argument("--trim_to_shortest", action="store_true", 
                       help="Trim output to shortest input file length to avoid padding artifacts")
    
    if args is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(args)

    # Set device
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    
    print(f'Using device: {device}')
    print(f'Ensemble type: {args.type}')
    print(f'Number of input files: {len(args.files)}')
    print(f'N_FFT: {args.n_fft}, Hop length: {args.hop_length}')
    
    if args.weights is not None:
        if len(args.weights) != len(args.files):
            raise ValueError("Number of weights must match number of files")
        weights = args.weights
    else:
        weights = [1.0] * len(args.files)
    
    print(f'Weights: {weights}')
    print(f'Output file: {args.output}')

    data = []
    max_len = 0
    min_len = float('inf')
    sample_rate = None
    
    for f in args.files:
        if not os.path.isfile(f):
            print(f"Error: can't find file: {f}. Check paths.")
            sys.exit(1)
        
        print(f'Reading file: {f}')
        wav, sr = torchaudio.load(f)
        
        if sample_rate is None:
            sample_rate = sr
        elif sample_rate != sr:
            print(f'Warning: Sample rate mismatch. Expected {sample_rate}, got {sr}. Resampling...')
            resampler = torchaudio.transforms.Resample(sr, sample_rate)
            wav = resampler(wav)

        # Ensure stereo (2 channels)
        if wav.shape[0] == 1:
            print("Mono detected. Converting to stereo by duplication.")
            wav = wav.repeat(2, 1)
        elif wav.shape[0] > 2:
            print(f"Multi-channel audio detected ({wav.shape[0]} channels). Using first 2 channels.")
            wav = wav[:2]

        # Move to device
        wav = wav.to(device)
        
        print(f"Waveform shape: {wav.shape} sample rate: {sr}")
        data.append(wav)
        max_len = max(max_len, wav.shape[1])
        min_len = min(min_len, wav.shape[1])

    # Choose target length based on argument
    target_len = min_len if args.trim_to_shortest else max_len
    target_tensor = torch.zeros(2, target_len, device=device)
    data = [match_tensor_shapes(wav, target_tensor) for wav in data]

    print(f"Target length: {target_len} ({'shortest' if args.trim_to_shortest else 'longest'} input file)")

    print("Starting ensemble processing...")
    result = average_waveforms(data, weights, args.type, args.n_fft, args.hop_length)
    
    print(f'Result shape: {result.shape}')
    
    # Save the result
    save_audio(result, sample_rate, args.output)
    print(f'Ensemble saved to: {args.output}')
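
# Example invocation (hypothetical file names; the script name is assumed):
#
#   python ensemble.py --files take_a.wav take_b.wav --type avg_fft \
#       --weights 2 1 --output result.flac --trim_to_shortest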


if __name__ == "__main__":
    ensemble_files(None)