import torch
from torch.utils.data import Dataset
from transformers import AutoFeatureExtractor
import os
import librosa
import numpy as np

class DemoDataset(Dataset):
    def __init__(self, demonstration_paths, demonstration_labels, query_path, sample_rate=16000):
        self.sample_rate = sample_rate
        self.query_path = query_path
        
        # Accept a single path (string) as well as a list of paths
        if isinstance(demonstration_paths, str):
            demonstration_paths = [demonstration_paths]
        self.demonstration_paths = demonstration_paths
        # Map string labels to integers: 0 = bonafide (genuine), 1 = spoof
        self.demonstration_labels = [0 if label == 'bonafide' else 1 for label in demonstration_labels]

        # Load the w2v-BERT 2.0 feature extractor
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

    def load_pad(self, path, max_length=64000):
        """Load an audio file and pad or truncate it to max_length samples (64000 = 4 s at 16 kHz)"""
        X, sr = librosa.load(path, sr=self.sample_rate)
        X = self.pad(X, max_length)
        return X

    def pad(self, x, max_len=64000):
        """Zero-pad (or truncate) audio to a fixed length"""
        x_len = x.shape[0]
        if x_len >= max_len:
            return x[:max_len]
        pad_length = max_len - x_len
        return np.concatenate([x, np.zeros(pad_length)], axis=0)
    
    def __len__(self):
        return 1  # Only one query audio
    
    def __getitem__(self, idx):
        # Load query audio
        query_waveform = self.load_pad(self.query_path)
        query_waveform = torch.from_numpy(query_waveform).float()
        if len(query_waveform.shape) == 1:
            query_waveform = query_waveform.unsqueeze(0)
        
        # Extract features for query audio
        main_features = self.feature_extractor(
            query_waveform, 
            sampling_rate=self.sample_rate,
            padding=True,
            return_attention_mask=True,
            return_tensors="pt"
        )
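        # NOTE: for facebook/w2v-bert-2.0 the extractor output is expected to
        # hold 'input_features' and 'attention_mask'; the exact key names
        # depend on the extractor class, so treat this as an assumption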
        
        # Process demonstration audios
        prompt_features = []
        for demo_path in self.demonstration_paths:
            # Load demonstration audio
            demo_waveform = self.load_pad(demo_path)
            demo_waveform = torch.from_numpy(demo_waveform).float()
            if len(demo_waveform.shape) == 1:
                demo_waveform = demo_waveform.unsqueeze(0)
            
            # Extract features
            prompt_feature = self.feature_extractor(
                demo_waveform,
                sampling_rate=self.sample_rate,
                padding=True,
                return_attention_mask=True,
                return_tensors="pt"
            )
            prompt_features.append(prompt_feature)
        
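        # Shape [1, num_prompts]; collate_fn later concatenates these along
        # dim 0 into [batch_size, num_prompts]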
        prompt_labels = torch.tensor([self.demonstration_labels], dtype=torch.long)
        
        return {
            'main_features': main_features,
            'prompt_features': prompt_features,
            'prompt_labels': prompt_labels,
            'file_name': os.path.basename(self.query_path),
            'file_path': self.query_path
        }

def collate_fn(batch):
    """
    Collate function for dataloader
    Args:
        batch: List containing dictionaries with:
            - main_features: feature extractor output
            - prompt_features: list of feature extractor outputs
            - file_name: file name
            - file_path: file path
    """
    
    # Process main features
    main_features_keys = batch[0]['main_features'].keys()
    main_features = {}
    for key in main_features_keys:
        main_features[key] = torch.cat([item['main_features'][key] for item in batch], dim=0)
    
    # Get number of prompts
    num_prompts = len(batch[0]['prompt_features'])
    
    # Process prompt features
    prompt_features = []
    for i in range(num_prompts):
        prompt_feature = {}
        for key in main_features_keys:
            prompt_feature[key] = torch.cat([item['prompt_features'][i][key] for item in batch], dim=0)
        prompt_features.append(prompt_feature)
    
    # Collect file names and paths
    file_names = [item['file_name'] for item in batch]
    file_paths = [item['file_path'] for item in batch]

    # Ensure prompt_labels has the correct shape: [batch_size, num_prompts]
    prompt_labels = torch.cat([item['prompt_labels'] for item in batch], dim=0)

    return {
        'main_features': main_features,
        'prompt_features': prompt_features,
        'prompt_labels': prompt_labels,
        'file_names': file_names,
        'file_paths': file_paths
    }

if __name__ == '__main__':
    # Smoke test: the example paths and labels below are placeholders and
    # must exist locally for this to run
    demo_paths = ["examples/demo1.wav", "examples/demo2.wav"]
    demo_labels = ['bonafide', 'spoof']
    query_path = "examples/query.wav"

    dataset = DemoDataset(demo_paths, demo_labels, query_path)
    print(dataset[0])
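
    # A minimal sketch of batched use with the collate_fn above; the file
    # paths are the same placeholders as in the smoke test, and only the
    # DataLoader wiring is the point here
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
    for batch in loader:
        print(batch['file_names'])
        print(batch['prompt_labels'].shape)  # expected: [batch_size, num_prompts]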