Update README.md
Browse files
README.md
CHANGED
@@ -36,4 +36,55 @@ The included emotions are:
|
|
36 |
</pre>
|
37 |
|
38 |
- Library: https://github.com/tiantiaf0627/vox-profile-release
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
</pre>
|
37 |
|
38 |
- Library: https://github.com/tiantiaf0627/vox-profile-release
|
39 |
+
|
40 |
+
# How to use this model
|
41 |
+
|
42 |
+
## Download repo
|
43 |
+
```
|
44 |
+
git clone git@github.com:tiantiaf0627/vox-profile-release.git
|
45 |
+
```
|
46 |
+
## Install the package
|
47 |
+
```
|
48 |
+
conda create -n vox_profile python=3.8
|
49 |
+
cd vox-profile-release
|
50 |
+
pip install -e .
|
51 |
+
```
|
52 |
+
|
53 |
+
## Load the model
|
54 |
+
```
|
55 |
+
# Load libraries
|
56 |
+
import torch
|
57 |
+
import torch.nn.functional as F
|
58 |
+
from src.model.emotion.whisper_emotion import WhisperWrapper
|
59 |
+
# Find device
|
60 |
+
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
|
61 |
+
# Load model from Huggingface
|
62 |
+
model = WhisperWrapper.from_pretrained("tiantiaf/whisper-large-v3-msp-podcast-emotion").to(device)
|
63 |
+
model.eval()
|
64 |
+
```
|
65 |
+
|
66 |
+
## Prediction
|
67 |
+
```
|
68 |
+
# Label List
|
69 |
+
emotion_label_list = [
|
70 |
+
'Anger',
|
71 |
+
'Contempt',
|
72 |
+
'Disgust',
|
73 |
+
'Fear',
|
74 |
+
'Happiness',
|
75 |
+
'Neutral',
|
76 |
+
'Sadness',
|
77 |
+
'Surprise',
|
78 |
+
'Other'
|
79 |
+
]
|
80 |
+
|
81 |
+
# Load data. Zeros are used here as a placeholder; real audio must be 16 kHz, mono.
|
82 |
+
data = torch.zeros([1, 16000]).float().to(device)
|
83 |
+
logits, embedding, _, _, _, _ = model(
|
84 |
+
data, return_feature=True
|
85 |
+
)
|
86 |
+
|
87 |
+
# Probability and output
|
88 |
+
emotion_prob = F.softmax(logits, dim=1)
|
89 |
+
print(emotion_label_list[torch.argmax(emotion_prob).detach().cpu().item()])
|
90 |
+
```
|