Update README.md
tags:
- speech-emotion-recognition
---

This model is a recreation of [3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes) implemented directly in torch, with its own class definition and feed-forward method. It was recreated in the hope of allowing greater flexibility of control and of training/fine-tuning the model. It was trained on the same [MSP-Podcast](https://ecs.utdallas.edu/research/researchlabs/msp-lab/MSP-Podcast.html) dataset as the original, but on a different, smaller subset. The subset is evenly distributed across gender and emotion category, in the hope that training on it would improve the accuracy of valence and arousal predictions.
This is therefore a multi-attribute model that predicts arousal, dominance, and valence. Unlike the original model, however, I kept the attribute score range of 0...7 (the range the dataset uses). I will provide evaluations later on; for now, this repo is here so that others can test the model and judge the inference accuracy themselves, or retrain it from scratch, modify it, and so on. My best trained weights so far are provided in this repo. The class definition for the model can be found in my [github](https://github.com/PhilipAmadasun/SER-Model-for-dimensional-attribute-prediction#).
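If you prefer to fetch the checkpoint programmatically instead of downloading it by hand, below is a minimal sketch using `huggingface_hub.hf_hub_download`; the `repo_id` and `filename` are placeholders, not this repo's actual values, so replace them with the real ones.

```python
# Sketch only: repo_id and filename below are placeholders, not the actual values.
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="<user>/<this-repo>",   # placeholder: this model repo's id on the Hub
    filename="<model.pt file>",     # placeholder: the checkpoint file name in the repo
)
print(checkpoint_path)  # local path that can be passed to torch.load
```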
# Usage

## Inference Testing

```python
import torch
import torchaudio

from SER_Model_setup import SERModel

device = "cuda" if torch.cuda.is_available() else "cpu"

checkpoint_path = "<model.pt file>"
checkpoint = torch.load(checkpoint_path, map_location=device)

# Create the model architecture and load weights
model = SERModel()
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

audio_path = "<wav file>"
audio, sr = torchaudio.load(audio_path)

# Resample if the file's sample rate differs from the model's
if sr != model.sample_rate:
    resampler = torchaudio.transforms.Resample(sr, model.sample_rate)
    audio = resampler(audio)

# Convert stereo -> mono if needed
if audio.shape[0] > 1:
    audio = torch.mean(audio, dim=0, keepdim=True)

audio_len = audio.shape[-1]

# Create waveform tensor (shape: [1, audio_len])
waveform = torch.zeros(1, audio_len, dtype=torch.float32)
waveform[0, :audio_len] = audio

# Create mask as a 2D tensor: shape [1, audio_len] with ones in the valid region
mask = torch.ones(1, audio_len, dtype=torch.float32)

# Move waveform and mask to device
waveform = waveform.to(device)
mask = mask.to(device)

# Normalize waveform using the model's mean and std
mean = model.mean.to(device)
std = model.std.to(device)
waveform = (waveform - mean) / (std + 1e-6)

with torch.no_grad():
    predictions = model(waveform, mask)  # predictions shape: [1, 3]

# Extract predictions: [0, 0] is arousal, [0, 1] is valence, [0, 2] is dominance
arousal = predictions[0, 0].item()
valence = predictions[0, 1].item()
dominance = predictions[0, 2].item()

print(f"Arousal: {arousal:.3f}")
print(f"Valence: {valence:.3f}")
print(f"Dominance: {dominance:.3f}")
```
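The printed values stay on the dataset's 0...7 scale. If you want something easier to read at a glance, the sketch below (continuing from the `arousal`, `valence`, and `dominance` variables above) rescales each score to roughly [-1, 1] and assigns a coarse arousal/valence quadrant; the 3.5 midpoint and the labels are my own assumptions for illustration, not part of the model.

```python
# Illustrative post-processing only; the midpoint (3.5) and labels are assumptions.
def rescale(score, midpoint=3.5):
    """Map a 0...7 attribute score to roughly [-1, 1]."""
    return (score - midpoint) / midpoint

def quadrant(arousal, valence, midpoint=3.5):
    """Coarse arousal/valence quadrant label."""
    a = "high-arousal" if arousal >= midpoint else "low-arousal"
    v = "positive" if valence >= midpoint else "negative"
    return f"{a}/{v}"

print(rescale(arousal), rescale(valence), rescale(dominance))
print(quadrant(arousal, valence))
```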
## Batch Inference

```python
import os
import glob

import torch
import torchaudio

from SER_Model_setup import SERModel  # Adjust if your model code is elsewhere


def load_model_from_checkpoint(checkpoint_path, device='cpu'):
    """
    Loads the SERModel and weights from a checkpoint, moves it to the device, and sets eval mode.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Create the model architecture and load the weights
    model = SERModel()
    model.load_state_dict(checkpoint['model_state_dict'])

    model.to(device)
    model.eval()
    return model


def batch_inference(model, file_paths, device='cpu', normalize=True):
    """
    Perform true batch inference on multiple .wav files in one forward pass.

    Args:
        model (SERModel): The loaded SER model in eval mode
        file_paths (list[str]): List of paths to .wav files
        device (str or torch.device): 'cpu' or 'cuda'
        normalize (bool): Whether to normalize waveforms (subtract mean, divide by std)

    Returns:
        dict: {filename: {"arousal": float, "valence": float, "dominance": float}}
    """
    # 1) Load and store all waveforms in memory
    waveforms_list = []
    lengths = []
    for fp in file_paths:
        # Load audio
        audio, sr = torchaudio.load(fp)

        # Resample if needed
        if sr != model.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, model.sample_rate)
            audio = resampler(audio)

        # Convert stereo -> mono if needed
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # audio shape => [1, num_samples]
        lengths.append(audio.shape[-1])
        waveforms_list.append(audio)

    # 2) Determine max length
    max_len = max(lengths)

    # 3) Pad each waveform to the max length and build masks
    batch_size = len(waveforms_list)
    batched_waveforms = torch.zeros(batch_size, 1, max_len, dtype=torch.float32)
    masks = torch.zeros(batch_size, max_len, dtype=torch.float32)

    for i, audio in enumerate(waveforms_list):
        cur_len = audio.shape[-1]
        batched_waveforms[i, :, :cur_len] = audio
        masks[i, :cur_len] = 1.0  # valid portion

    # 4) Move batched data to the device BEFORE normalization
    batched_waveforms = batched_waveforms.to(device)
    masks = masks.to(device)

    # 5) Normalize if needed (model.mean, model.std)
    if normalize:
        # model.mean and model.std are buffers; ensure they're on the correct device
        mean = model.mean.to(device)
        std = model.std.to(device)
        batched_waveforms = (batched_waveforms - mean) / (std + 1e-6)

    # 6) Single forward pass
    with torch.no_grad():
        predictions = model(batched_waveforms, masks)
        # predictions shape => [batch_size, 3]

    # 7) Build the result dict
    results = {}
    for i, fp in enumerate(file_paths):
        arousal = predictions[i, 0].item()
        valence = predictions[i, 1].item()
        dominance = predictions[i, 2].item()
        filename = os.path.basename(fp)
        results[filename] = {
            "arousal": arousal,
            "valence": valence,
            "dominance": dominance,
        }

    return results


if __name__ == "__main__":
    # Example usage
    device = "cuda" if torch.cuda.is_available() else "cpu"

    checkpoint_path = "<weights.pt>"
    model = load_model_from_checkpoint(checkpoint_path, device=device)

    # Suppose you have a folder of .wav files
    wav_folder = "<directory containing .wav files>"
    wav_paths = glob.glob(os.path.join(wav_folder, "*.wav"))

    # Do a single pass of batch inference
    all_results = batch_inference(model, wav_paths, device=device, normalize=True)

    # Print results
    for fname, preds in all_results.items():
        print(f"{fname}: Arousal={preds['arousal']:.3f}, "
              f"Valence={preds['valence']:.3f}, Dominance={preds['dominance']:.3f}")
```
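Because every clip in the batch is padded to the longest one, a folder containing a few very long files can use a lot of memory in a single forward pass. A simple workaround, shown in the sketch below reusing the `batch_inference` function above, is to process the folder in smaller chunks and merge the results; the chunk size of 16 is an arbitrary choice, not a recommendation.

```python
# Sketch: process a large folder in chunks to keep the padded batch small.
def chunked_batch_inference(model, file_paths, device="cpu", chunk_size=16):
    # chunk_size is an arbitrary choice; tune it to your GPU/CPU memory.
    results = {}
    for start in range(0, len(file_paths), chunk_size):
        chunk = file_paths[start:start + chunk_size]
        results.update(batch_inference(model, chunk, device=device, normalize=True))
    return results
```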