Samara369 commited on
Commit
f261a02
Β·
verified Β·
1 Parent(s): 85b3a9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -27
app.py CHANGED
@@ -5,8 +5,6 @@ import os
5
  import torch
6
  import librosa
7
  from transformers import AutoModel
8
- from huggingface_hub import hf_hub_download
9
-
10
  import net
11
  import utils
12
 
@@ -17,15 +15,8 @@ ser_model = None
17
  wav_mean = None
18
  wav_std = None
19
 
20
- # HF Model repo path (you uploaded the model files here)
21
- HF_MODEL_REPO = "Samara369/SER_1"
22
- MODEL_FILES = {
23
- "ssl": "trained_models/final_ssl.pt",
24
- "pool": "trained_models/final_pool.pt",
25
- "ser": "trained_models/final_ser.pt",
26
- "stat": "trained_models/train_norm_stat.pkl"
27
- }
28
-
29
  SSL_TYPE = utils.get_ssl_type("wavlm-large")
30
  POOLING_TYPE = "AttentiveStatisticsPooling"
31
  HEAD_DIM = 1024
@@ -35,36 +26,43 @@ EMOTION_NAMES = {
35
  'F': 'Fear', 'D': 'Disgust', 'C': 'Contempt', 'N': 'Neutral'
36
  }
37
 
 
 
38
  def load_models():
39
  global ssl_model, pool_model, ser_model, wav_mean, wav_std
40
 
41
  if ssl_model is None:
42
- print("πŸ”„ Downloading and loading models on CPU...")
43
-
44
- # Download all model files using huggingface_hub
45
- ssl_path = hf_hub_download(HF_MODEL_REPO, MODEL_FILES["ssl"])
46
- pool_path = hf_hub_download(HF_MODEL_REPO, MODEL_FILES["pool"])
47
- ser_path = hf_hub_download(HF_MODEL_REPO, MODEL_FILES["ser"])
48
- stat_path = hf_hub_download(HF_MODEL_REPO, MODEL_FILES["stat"])
49
-
50
- ssl_model = AutoModel.from_pretrained(SSL_TYPE)
51
- ssl_model.freeze_feature_encoder()
52
- ssl_model.load_state_dict(torch.load(ssl_path, map_location='cpu'))
 
53
  ssl_model.eval()
54
 
 
55
  feat_dim = ssl_model.config.hidden_size
56
  pool_net = getattr(net, POOLING_TYPE)
57
  pool_model = pool_net(feat_dim)
58
- pool_model.load_state_dict(torch.load(pool_path, map_location='cpu'))
59
  pool_model.eval()
60
 
 
61
  dh_input_dim = feat_dim * 2 if POOLING_TYPE == "AttentiveStatisticsPooling" else feat_dim
62
  ser_model = net.EmotionRegression(dh_input_dim, HEAD_DIM, 1, 8, dropout=0.5)
63
- ser_model.load_state_dict(torch.load(ser_path, map_location='cpu'))
64
  ser_model.eval()
65
 
66
- wav_mean, wav_std = utils.load_norm_stat(stat_path)
67
- print("βœ… Models loaded.")
 
 
68
 
69
  def process_single_audio(wav_path):
70
  wav, _ = librosa.load(wav_path, sr=16000)
@@ -89,7 +87,7 @@ def predict_emotion(audio_path):
89
  return predicted_emotion, confidence, logits
90
 
91
  except Exception as e:
92
- print("❌ Error during inference:", e)
93
  return "Error", 0.0, None
94
 
95
  def process_audio_file(audio_file):
 
5
  import torch
6
  import librosa
7
  from transformers import AutoModel
 
 
8
  import net
9
  import utils
10
 
 
15
  wav_mean = None
16
  wav_std = None
17
 
18
+ # Configuration
19
+ MODEL_PATH = "trained_models"
 
 
 
 
 
 
 
20
  SSL_TYPE = utils.get_ssl_type("wavlm-large")
21
  POOLING_TYPE = "AttentiveStatisticsPooling"
22
  HEAD_DIM = 1024
 
26
  'F': 'Fear', 'D': 'Disgust', 'C': 'Contempt', 'N': 'Neutral'
27
  }
28
 
29
+ from huggingface_hub import hf_hub_download
30
+
31
  def load_models():
32
  global ssl_model, pool_model, ser_model, wav_mean, wav_std
33
 
34
  if ssl_model is None:
35
+ print("Downloading and loading models from Hugging Face...")
36
+
37
+ # Paths to files in the repo
38
+ repo_id = "Samara369/SER_1"
39
+ ssl_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_ssl.pt")
40
+ pool_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_pool.pt")
41
+ ser_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_ser.pt")
42
+ norm_path = hf_hub_download(repo_id=repo_id, filename="trained_models/train_norm_stat.pkl")
43
+
44
+ # Load SSL model
45
+ ssl_model = AutoModel.from_pretrained("microsoft/wavlm-large")
46
+ ssl_model.load_state_dict(torch.load(ssl_path, map_location="cpu"))
47
  ssl_model.eval()
48
 
49
+ # Load pooling model
50
  feat_dim = ssl_model.config.hidden_size
51
  pool_net = getattr(net, POOLING_TYPE)
52
  pool_model = pool_net(feat_dim)
53
+ pool_model.load_state_dict(torch.load(pool_path, map_location="cpu"))
54
  pool_model.eval()
55
 
56
+ # Load regression head
57
  dh_input_dim = feat_dim * 2 if POOLING_TYPE == "AttentiveStatisticsPooling" else feat_dim
58
  ser_model = net.EmotionRegression(dh_input_dim, HEAD_DIM, 1, 8, dropout=0.5)
59
+ ser_model.load_state_dict(torch.load(ser_path, map_location="cpu"))
60
  ser_model.eval()
61
 
62
+ # Load normalization stats
63
+ wav_mean, wav_std = utils.load_norm_stat(norm_path)
64
+
65
+ print("Models loaded from Hugging Face.")
66
 
67
  def process_single_audio(wav_path):
68
  wav, _ = librosa.load(wav_path, sr=16000)
 
87
  return predicted_emotion, confidence, logits
88
 
89
  except Exception as e:
90
+ print("Error during inference:", e)
91
  return "Error", 0.0, None
92
 
93
  def process_audio_file(audio_file):