uyiosa committed (verified) · Commit 0290152 · Parent(s): 26ab833

Update README.md

Files changed (1): README.md (+204 -1)
README.md CHANGED
@@ -17,4 +17,207 @@ tags:
  - speech-emotion-recognition
  ---
  The model is a recreation of [3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes) for direct implementation in PyTorch, with a class definition and feed-forward method. It was recreated in the hope of allowing greater flexibility in controlling, training, and fine-tuning the model. The model was trained on the same [MSP-Podcast](https://ecs.utdallas.edu/research/researchlabs/msp-lab/MSP-Podcast.html) dataset as the original, but on a different, smaller subset. The subset is evenly distributed across gender and emotion category, in the hope that training would improve the accuracy of valence and arousal predictions.
- This model is therefore a multi-attribute model that predicts arousal, dominance, and valence. However, unlike the original model, this one keeps the dataset's original attribute score range of 0...7. I will provide evaluations later on. For now, I made this repo so that others can test the model and judge its inference accuracy themselves. My best trained weights as of now are provided in this repo. The class definition for the model can be found in my [github](https://github.com/PhilipAmadasun/SER-Model-for-dimensional-attribute-prediction#).
+ This model is therefore a multi-attribute model that predicts arousal, dominance, and valence. However, unlike the original model, this one keeps the dataset's original attribute score range of 0...7. I will provide evaluations later on. For now, I made this repo so that others can test the model and judge its inference accuracy themselves, or retrain it from scratch, modify it, and so on. My best trained weights as of now are provided in this repo. The class definition for the model can be found in my [github](https://github.com/PhilipAmadasun/SER-Model-for-dimensional-attribute-prediction#).
+
+ # Usage
+ ## Inference Testing
+ ```python
+ import torch
+ import torchaudio
+ from SER_Model_setup import SERModel
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ checkpoint_path = "<model.pt file>"
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+
+ # Create the model architecture and load weights
+ model = SERModel()
+ model.load_state_dict(checkpoint['model_state_dict'])
+ model.to(device)
+ model.eval()
+
+ audio_path = "<wav file>"
+ audio, sr = torchaudio.load(audio_path)
+
+ # Resample to the model's expected sample rate if needed
+ if sr != model.sample_rate:
+     resampler = torchaudio.transforms.Resample(sr, model.sample_rate)
+     audio = resampler(audio)
+
+ # Convert stereo -> mono if needed
+ if audio.shape[0] > 1:
+     audio = torch.mean(audio, dim=0, keepdim=True)
+
+ audio_len = audio.shape[-1]
+
+ # Create waveform tensor (shape: [1, audio_len])
+ waveform = torch.zeros(1, audio_len, dtype=torch.float32)
+ waveform[0, :audio_len] = audio
+
+ # Create mask as a 2D tensor: shape [1, audio_len], with ones in the valid region
+ mask = torch.ones(1, audio_len, dtype=torch.float32)
+
+ # Move waveform and mask to device
+ waveform = waveform.to(device)
+ mask = mask.to(device)
+
+ # Normalize waveform using the model's mean and std buffers
+ mean = model.mean.to(device)
+ std = model.std.to(device)
+ waveform = (waveform - mean) / (std + 1e-6)
+
+ with torch.no_grad():
+     predictions = model(waveform, mask)  # predictions shape: [1, 3]
+
+ # Extract predictions: [0, 0] is arousal, [0, 1] valence, [0, 2] dominance
+ arousal = predictions[0, 0].item()
+ valence = predictions[0, 1].item()
+ dominance = predictions[0, 2].item()
+
+ print(f"Arousal: {arousal:.3f}")
+ print(f"Valence: {valence:.3f}")
+ print(f"Dominance: {dominance:.3f}")
+ ```
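+
+ Since predictions keep the dataset's 0...7 attribute range, downstream code often wants them rescaled to [0, 1]. Below is a minimal sketch of such a rescale; the 0...7 bounds come from the description above, while the clipping of out-of-range values is my own assumption, since regression outputs are not guaranteed to stay in range:
+
+ ```python
+ def rescale_attribute(score: float, low: float = 0.0, high: float = 7.0) -> float:
+     """Map a raw attribute score from [low, high] to [0, 1], clipping outliers."""
+     scaled = (score - low) / (high - low)
+     return min(max(scaled, 0.0), 1.0)
+
+ # Example: rescale_attribute(arousal) on the values printed by the script above
+ ```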
+ ## Batch Inference
+ ```python
+ import os
+ import glob
+ import torch
+ import torchaudio
+ from SER_Model_setup import SERModel  # Adjust if your model code is elsewhere
+
+ def load_model_from_checkpoint(checkpoint_path, device='cpu'):
+     """
+     Loads the SERModel and weights from a checkpoint, moves it to the device, and sets eval mode.
+     """
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+
+     # Create the model architecture and load weights
+     model = SERModel()
+     model.load_state_dict(checkpoint['model_state_dict'])
+
+     model.to(device)
+     model.eval()
+     return model
+
+ def batch_inference(model, file_paths, device='cpu', normalize=True):
+     """
+     Perform true batch inference on multiple .wav files in one forward pass.
+
+     Args:
+         model (SERModel): The loaded SER model in eval mode
+         file_paths (list[str]): List of paths to .wav files
+         device (str or torch.device): 'cpu' or 'cuda'
+         normalize (bool): Whether to normalize waveforms (subtract mean, divide by std)
+
+     Returns:
+         dict: {filename: {"arousal": float, "valence": float, "dominance": float}}
+     """
+     # 1) Load and store all waveforms in memory
+     waveforms_list = []
+     lengths = []
+     for fp in file_paths:
+         audio, sr = torchaudio.load(fp)
+
+         # Resample if needed
+         if sr != model.sample_rate:
+             resampler = torchaudio.transforms.Resample(sr, model.sample_rate)
+             audio = resampler(audio)
+
+         # Convert stereo -> mono if needed
+         if audio.shape[0] > 1:
+             audio = torch.mean(audio, dim=0, keepdim=True)
+
+         # audio shape => [1, num_samples]
+         lengths.append(audio.shape[-1])
+         waveforms_list.append(audio)
+
+     # 2) Determine the max length
+     max_len = max(lengths)
+
+     # 3) Pad each waveform to the max length and build masks
+     batch_size = len(waveforms_list)
+     batched_waveforms = torch.zeros(batch_size, 1, max_len, dtype=torch.float32)
+     masks = torch.zeros(batch_size, max_len, dtype=torch.float32)
+
+     for i, audio in enumerate(waveforms_list):
+         cur_len = audio.shape[-1]
+         batched_waveforms[i, :, :cur_len] = audio
+         masks[i, :cur_len] = 1.0  # valid portion
+
+     # 4) Move batched data to the device BEFORE normalization
+     batched_waveforms = batched_waveforms.to(device)
+     masks = masks.to(device)
+
+     # 5) Normalize if requested (model.mean and model.std are buffers; ensure they're on the correct device)
+     if normalize:
+         mean = model.mean.to(device)
+         std = model.std.to(device)
+         batched_waveforms = (batched_waveforms - mean) / (std + 1e-6)
+
+     # 6) Single forward pass; predictions shape => [batch_size, 3]
+     with torch.no_grad():
+         predictions = model(batched_waveforms, masks)
+
+     # 7) Build the result dict
+     results = {}
+     for i, fp in enumerate(file_paths):
+         arousal = predictions[i, 0].item()
+         valence = predictions[i, 1].item()
+         dominance = predictions[i, 2].item()
+         filename = os.path.basename(fp)
+         results[filename] = {
+             "arousal": arousal,
+             "valence": valence,
+             "dominance": dominance
+         }
+
+     return results
+
+ if __name__ == "__main__":
+     # Example usage
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     checkpoint_path = "<weights.pt>"
+     model = load_model_from_checkpoint(checkpoint_path, device=device)
+
+     # Suppose you have a folder of .wav files
+     wav_folder = "<directory containing .wav files>"
+     wav_paths = glob.glob(os.path.join(wav_folder, "*.wav"))
+
+     # Do a single pass of batch inference
+     all_results = batch_inference(model, wav_paths, device=device, normalize=True)
+
+     # Print results
+     for fname, preds in all_results.items():
+         print(f"{fname}: Arousal={preds['arousal']:.3f}, "
+               f"Valence={preds['valence']:.3f}, Dominance={preds['dominance']:.3f}")
+ ```
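+
+ ## Fine-tuning Sketch
+ Since the point of this recreation is easier control and training/fine-tuning, below is a minimal fine-tuning loop. It is a sketch, not the recipe used to train the provided weights: it assumes the interface shown above (`model(waveforms, masks)` returning `[batch, 3]` predictions), a hypothetical `DataLoader` yielding already-normalized, padded waveforms with masks and 0...7 attribute targets, and plain MSE as a stand-in loss.
+
+ ```python
+ import torch
+
+ def fine_tune(model, dataloader, device="cpu", epochs=3, lr=1e-5):
+     """Minimal fine-tuning loop; dataloader yields (waveforms, masks, targets)."""
+     model.to(device)
+     model.train()
+     optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+     criterion = torch.nn.MSELoss()  # stand-in loss; swap in CCC or whatever you prefer
+
+     for epoch in range(epochs):
+         total_loss = 0.0
+         for waveforms, masks, targets in dataloader:
+             # waveforms: [B, 1, max_len], masks: [B, max_len], targets: [B, 3] in 0...7
+             waveforms = waveforms.to(device)
+             masks = masks.to(device)
+             targets = targets.to(device)
+
+             optimizer.zero_grad()
+             predictions = model(waveforms, masks)  # [B, 3]
+             loss = criterion(predictions, targets)
+             loss.backward()
+             optimizer.step()
+             total_loss += loss.item()
+
+         print(f"epoch {epoch}: mean loss {total_loss / max(len(dataloader), 1):.4f}")
+ ```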