Update app.py
app.py
CHANGED
@@ -5,8 +5,6 @@ import os
 import torch
 import librosa
 from transformers import AutoModel
-from huggingface_hub import hf_hub_download
-
 import net
 import utils

@@ -17,15 +15,8 @@ ser_model = None
 wav_mean = None
 wav_std = None

-#
-
-MODEL_FILES = {
-    "ssl": "trained_models/final_ssl.pt",
-    "pool": "trained_models/final_pool.pt",
-    "ser": "trained_models/final_ser.pt",
-    "stat": "trained_models/train_norm_stat.pkl"
-}
-
+# Configuration
+MODEL_PATH = "trained_models"
 SSL_TYPE = utils.get_ssl_type("wavlm-large")
 POOLING_TYPE = "AttentiveStatisticsPooling"
 HEAD_DIM = 1024
@@ -35,36 +26,43 @@ EMOTION_NAMES = {
     'F': 'Fear', 'D': 'Disgust', 'C': 'Contempt', 'N': 'Neutral'
 }

+from huggingface_hub import hf_hub_download
+
 def load_models():
     global ssl_model, pool_model, ser_model, wav_mean, wav_std

     if ssl_model is None:
-        print("
-
-        #
-
-
-
-
-
-
-
-        ssl_model.
+        print("Downloading and loading models from Hugging Face...")
+
+        # Paths to files in the repo
+        repo_id = "Samara369/SER_1"
+        ssl_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_ssl.pt")
+        pool_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_pool.pt")
+        ser_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_ser.pt")
+        norm_path = hf_hub_download(repo_id=repo_id, filename="trained_models/train_norm_stat.pkl")
+
+        # Load SSL model
+        ssl_model = AutoModel.from_pretrained("microsoft/wavlm-large")
+        ssl_model.load_state_dict(torch.load(ssl_path, map_location="cpu"))
         ssl_model.eval()

+        # Load pooling model
         feat_dim = ssl_model.config.hidden_size
         pool_net = getattr(net, POOLING_TYPE)
         pool_model = pool_net(feat_dim)
-        pool_model.load_state_dict(torch.load(pool_path, map_location=
+        pool_model.load_state_dict(torch.load(pool_path, map_location="cpu"))
         pool_model.eval()

+        # Load regression head
         dh_input_dim = feat_dim * 2 if POOLING_TYPE == "AttentiveStatisticsPooling" else feat_dim
         ser_model = net.EmotionRegression(dh_input_dim, HEAD_DIM, 1, 8, dropout=0.5)
-        ser_model.load_state_dict(torch.load(ser_path, map_location=
+        ser_model.load_state_dict(torch.load(ser_path, map_location="cpu"))
         ser_model.eval()

-
-
+        # Load normalization stats
+        wav_mean, wav_std = utils.load_norm_stat(norm_path)
+
+        print("Models loaded from Hugging Face.")

 def process_single_audio(wav_path):
     wav, _ = librosa.load(wav_path, sr=16000)
@@ -89,7 +87,7 @@ def predict_emotion(audio_path):
         return predicted_emotion, confidence, logits

     except Exception as e:
-        print("
+        print("Error during inference:", e)
         return "Error", 0.0, None

 def process_audio_file(audio_file):
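For a quick sanity check outside the Space, the new download-and-load path can be exercised on its own. The sketch below is not the Space's full pipeline: it assumes only what the diff shows (the Samara369/SER_1 repo, the trained_models/*.pt filenames, and the microsoft/wavlm-large backbone) and leaves out the project-specific net and utils modules, so only the SSL backbone is restored; the pooling and regression heads follow the same torch.load / load_state_dict pattern with their own .pt files.

# Minimal sketch (assumptions as noted above): pull one checkpoint from the Hub
# and load its weights into the WavLM backbone, mirroring what load_models() now does.
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel

repo_id = "Samara369/SER_1"

# hf_hub_download returns a local path inside the Hugging Face cache,
# so repeated calls in the same environment do not re-fetch the file.
ssl_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_ssl.pt")

# Start from the pretrained backbone, then overwrite it with the fine-tuned weights.
ssl_model = AutoModel.from_pretrained("microsoft/wavlm-large")
ssl_model.load_state_dict(torch.load(ssl_path, map_location="cpu"))
ssl_model.eval()
print("SSL backbone restored from", ssl_path)

Loading every checkpoint with map_location="cpu" keeps the app usable on CPU-only Space hardware; the tensors can still be moved to a GPU afterwards if one is available.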