import os

import librosa
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import AutoFeatureExtractor
class DemoDataset(Dataset):
    def __init__(self, demonstration_paths, demonstration_labels, query_path, sample_rate=16000):
        self.sample_rate = sample_rate
        self.query_path = query_path

        # Convert to a list if a single path/label is given
        if isinstance(demonstration_paths, str):
            demonstration_paths = [demonstration_paths]
        if isinstance(demonstration_labels, str):
            demonstration_labels = [demonstration_labels]
        self.demonstration_paths = demonstration_paths
        # Map string labels to integers: 'bonafide' -> 0, anything else (spoof) -> 1
        self.demonstration_labels = [0 if label == 'bonafide' else 1 for label in demonstration_labels]

        # Load feature extractor
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

    def load_pad(self, path, max_length=64000):
        """Load an audio file and pad/truncate it to a fixed length."""
        X, sr = librosa.load(path, sr=self.sample_rate)
        X = self.pad(X, max_length)
        return X

    def pad(self, x, max_len=64000):
        """Pad (or truncate) audio to a fixed length."""
        x_len = x.shape[0]
        if x_len >= max_len:
            return x[:max_len]
        pad_length = max_len - x_len
        return np.concatenate([x, np.zeros(pad_length)], axis=0)

    def __len__(self):
        return 1  # Only one query audio

    def __getitem__(self, idx):
        # Load query audio
        query_waveform = self.load_pad(self.query_path)
        query_waveform = torch.from_numpy(query_waveform).float()
        if len(query_waveform.shape) == 1:
            query_waveform = query_waveform.unsqueeze(0)

        # Extract features for the query audio
        main_features = self.feature_extractor(
            query_waveform,
            sampling_rate=self.sample_rate,
            padding=True,
            return_attention_mask=True,
            return_tensors="pt"
        )

        # Process demonstration audios
        prompt_features = []
        for demo_path in self.demonstration_paths:
            # Load demonstration audio
            demo_waveform = self.load_pad(demo_path)
            demo_waveform = torch.from_numpy(demo_waveform).float()
            if len(demo_waveform.shape) == 1:
                demo_waveform = demo_waveform.unsqueeze(0)

            # Extract features
            prompt_feature = self.feature_extractor(
                demo_waveform,
                sampling_rate=self.sample_rate,
                padding=True,
                return_attention_mask=True,
                return_tensors="pt"
            )
            prompt_features.append(prompt_feature)

        # Shape [1, num_prompts]; collate_fn concatenates these along dim 0
        prompt_labels = torch.tensor([self.demonstration_labels], dtype=torch.long)

        return {
            'main_features': main_features,
            'prompt_features': prompt_features,
            'prompt_labels': prompt_labels,
            'file_name': os.path.basename(self.query_path),
            'file_path': self.query_path
        }
def collate_fn(batch):
    """
    Collate function for the dataloader.

    Args:
        batch: list of dictionaries with:
            - main_features: feature extractor output for the query audio
            - prompt_features: list of feature extractor outputs, one per demonstration
            - prompt_labels: tensor of demonstration labels, shape [1, num_prompts]
            - file_name: file name
            - file_path: file path
    """
    # Process main features: concatenate each key across the batch
    main_features_keys = batch[0]['main_features'].keys()
    main_features = {}
    for key in main_features_keys:
        main_features[key] = torch.cat([item['main_features'][key] for item in batch], dim=0)

    # Get number of prompts
    num_prompts = len(batch[0]['prompt_features'])

    # Process prompt features: one batched dict per prompt position
    prompt_features = []
    for i in range(num_prompts):
        prompt_feature = {}
        for key in main_features_keys:
            prompt_feature[key] = torch.cat([item['prompt_features'][i][key] for item in batch], dim=0)
        prompt_features.append(prompt_feature)

    # Collect file names and paths
    file_names = [item['file_name'] for item in batch]
    file_paths = [item['file_path'] for item in batch]

    # Ensure prompt_labels has the correct shape: [batch_size, num_prompts]
    prompt_labels = torch.cat([item['prompt_labels'] for item in batch], dim=0)

    return {
        'main_features': main_features,
        'prompt_features': prompt_features,
        'prompt_labels': prompt_labels,
        'file_names': file_names,
        'file_paths': file_paths
    }
if __name__ == '__main__':
    # Test the dataset (labels must be passed alongside the demonstration paths;
    # any label other than 'bonafide' is mapped to 1, i.e. spoof)
    demo_paths = ["examples/demo1.wav", "examples/demo2.wav"]
    demo_labels = ["bonafide", "spoof"]
    query_path = "examples/query.wav"

    dataset = DemoDataset(demo_paths, demo_labels, query_path)
    print(dataset[0])
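    # A minimal sketch of wiring the dataset into a torch DataLoader with the
    # custom collate_fn above; batch_size=1 is an assumption that matches this
    # dataset, which holds a single query audio.
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
    for batch in loader:
        # main_features: dict of batched tensors for the query audio
        # prompt_features: list of dicts, one per demonstration
        # prompt_labels: tensor of shape [batch_size, num_prompts]
        print(batch['file_names'], batch['prompt_labels'].shape)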