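"""Gradio demo for speech emotion recognition (SER).

Downloads a fine-tuned WavLM-large encoder, an attentive-statistics pooling
layer, and an 8-class emotion head from the Hugging Face Hub, then serves an
upload/record interface on port 7860.
"""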
import os
import tempfile

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel

import net  # local module: pooling layers and EmotionRegression head
import utils  # local module: SSL type lookup and normalization-stat loading
# Global variables to store loaded models
ssl_model = None
pool_model = None
ser_model = None
wav_mean = None
wav_std = None
# Configuration
MODEL_PATH = "trained_models"
SSL_TYPE = utils.get_ssl_type("wavlm-large")
POOLING_TYPE = "AttentiveStatisticsPooling"
HEAD_DIM = 1024
EMOTION_MAP = ['A', 'S', 'H', 'U', 'F', 'D', 'C', 'N']
EMOTION_NAMES = {
    'A': 'Angry', 'S': 'Sad', 'H': 'Happy', 'U': 'Surprise',
    'F': 'Fear', 'D': 'Disgust', 'C': 'Contempt', 'N': 'Neutral'
}
def load_models():
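    """Lazily download and initialize all models on first call.

    Populates the module-level globals so repeated requests reuse the
    loaded weights instead of re-downloading them.
    """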
    global ssl_model, pool_model, ser_model, wav_mean, wav_std
    if ssl_model is None:
        print("Downloading and loading models from Hugging Face...")
        # Paths to files in the repo
        repo_id = "Samara369/SER_1"
        ssl_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_ssl.pt")
        pool_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_pool.pt")
        ser_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_ser.pt")
        norm_path = hf_hub_download(repo_id=repo_id, filename="trained_models/train_norm_stat.pkl")
        # Load SSL model and overwrite its weights with the fine-tuned checkpoint
        ssl_model = AutoModel.from_pretrained("microsoft/wavlm-large")
        ssl_model.load_state_dict(torch.load(ssl_path, map_location="cpu"))
        ssl_model.eval()
        # Load pooling model
        feat_dim = ssl_model.config.hidden_size
        pool_net = getattr(net, POOLING_TYPE)
        pool_model = pool_net(feat_dim)
        pool_model.load_state_dict(torch.load(pool_path, map_location="cpu"))
        pool_model.eval()
        # Load emotion head; attentive statistics pooling doubles the feature
        # dimension (it concatenates the weighted mean and std)
        dh_input_dim = feat_dim * 2 if POOLING_TYPE == "AttentiveStatisticsPooling" else feat_dim
        ser_model = net.EmotionRegression(dh_input_dim, HEAD_DIM, 1, 8, dropout=0.5)
        ser_model.load_state_dict(torch.load(ser_path, map_location="cpu"))
        ser_model.eval()
        # Load normalization stats
        wav_mean, wav_std = utils.load_norm_stat(norm_path)
        print("Models loaded from Hugging Face.")
def process_single_audio(wav_path):
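    """Load a wav file at 16 kHz, apply the training-set waveform
    normalization, and return a (1, T) tensor plus an all-ones mask."""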
    wav, _ = librosa.load(wav_path, sr=16000)
    wav = torch.tensor(wav).float()
    wav = (wav - wav_mean) / wav_std
    wav = wav.unsqueeze(0)  # add batch dimension -> (1, T)
    mask = torch.ones_like(wav)  # no padding, so every sample is attended
    return wav, mask
def predict_emotion(audio_path):
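    """Run encoder -> pooling -> emotion head on a single audio file.

    Returns (emotion_code, confidence, per_class_scores). The head's raw
    outputs are softmax-normalized here so the confidence fits the UI's
    [0, 1] slider.
    """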
    try:
        x, mask = process_single_audio(audio_path)
        with torch.no_grad():
            ssl_out = ssl_model(x, attention_mask=mask).last_hidden_state
            pooled = pool_model(ssl_out, mask)
            logits = ser_model(pooled)
            # Softmax the unnormalized outputs so the scores are comparable
            # probabilities bounded in [0, 1]
            probs = torch.softmax(logits, dim=-1).cpu().numpy().flatten()
        emotion_index = int(np.argmax(probs))
        predicted_emotion = EMOTION_MAP[emotion_index]
        confidence = float(probs[emotion_index])
        return predicted_emotion, confidence, probs
    except Exception as e:
        print("Error during inference:", e)
        return "Error", 0.0, None
def process_audio_file(audio_file):
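    """Gradio handler for the file-upload path; formats a Markdown report."""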
    if audio_file is None:
        return "No audio file provided", 0.0, ""
    load_models()
    emotion_code, confidence, probabilities = predict_emotion(audio_file)
    if emotion_code == "Error":
        return "Error", 0.0, "❌ Error processing audio."
    emotion_name = EMOTION_NAMES.get(emotion_code, emotion_code)
    result_text = f"🎵 **Predicted Emotion:** {emotion_name} ({emotion_code})\n\n"
    result_text += f"🎯 **Confidence:** {confidence:.4f}\n\n"
    result_text += "**All Emotion Scores:**\n"
    for i, score in enumerate(probabilities):
        name = EMOTION_NAMES.get(EMOTION_MAP[i], EMOTION_MAP[i])
        result_text += f"- {name} ({EMOTION_MAP[i]}): {score:.4f}\n"
    return emotion_name, confidence, result_text
def process_live_audio(audio_data):
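    """Gradio handler for recorded audio.

    gr.Audio with type="numpy" yields (sample_rate, samples); the samples
    are written to a temporary wav so the same file-based pipeline applies.
    """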
    if audio_data is None:
        return "No audio recorded", 0.0, ""
    # audio_data is (sample_rate, numpy_array); write it out so librosa can reload it
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    temp_file.close()
    sf.write(temp_file.name, audio_data[1], audio_data[0])
    load_models()
    emotion_code, confidence, probabilities = predict_emotion(temp_file.name)
    os.unlink(temp_file.name)
    if emotion_code == "Error":
        return "Error", 0.0, "❌ Error processing live audio."
    emotion_name = EMOTION_NAMES.get(emotion_code, emotion_code)
    result_text = f"🎤 **Predicted Emotion:** {emotion_name} ({emotion_code})\n\n"
    result_text += f"🎯 **Confidence:** {confidence:.4f}\n\n"
    result_text += "**All Emotion Scores:**\n"
    for i, score in enumerate(probabilities):
        name = EMOTION_NAMES.get(EMOTION_MAP[i], EMOTION_MAP[i])
        result_text += f"- {name} ({EMOTION_MAP[i]}): {score:.4f}\n"
    return emotion_name, confidence, result_text
def create_ser_interface():
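    """Build the two-column Gradio Blocks UI: inputs on the left, results on the right."""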
    with gr.Blocks(title="Speech Emotion Recognition") as interface:
        gr.Markdown("## 🎙️ Speech Emotion Recognition\nUpload or record audio to detect emotion.")
        with gr.Row():
            with gr.Column():
                file_input = gr.Audio(label="Upload Audio File", type="filepath")
                file_btn = gr.Button("Analyze Uploaded Audio")
                gr.Markdown("---")
                live_audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Record Live Audio")
                live_btn = gr.Button("Analyze Live Audio")
            with gr.Column():
                emotion_output = gr.Textbox(label="Predicted Emotion", interactive=False)
                confidence_output = gr.Slider(minimum=0, maximum=1, label="Confidence", interactive=False)
                results_output = gr.Markdown(label="Results")
        file_btn.click(process_audio_file, inputs=[file_input], outputs=[emotion_output, confidence_output, results_output])
        live_btn.click(process_live_audio, inputs=[live_audio], outputs=[emotion_output, confidence_output, results_output])
    return interface
if __name__ == "__main__":
    interface = create_ser_interface()
    interface.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
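# To run locally (assumes net.py, utils.py, and access to the Samara369/SER_1
# repo on the Hugging Face Hub; dependency list inferred from the imports):
#   pip install gradio torch librosa soundfile transformers huggingface_hub
#   python app.py   # then open http://localhost:7860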