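"""Gradio app for speech emotion recognition (SER).

Extracts WavLM-Large SSL features, pools them with attentive statistics
pooling, and scores 8 emotion categories with a small head. Model weights
are downloaded from the Samara369/SER_1 Hugging Face repo.
"""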
import os
import tempfile

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel

import net
import utils

# Lazily loaded models and waveform normalization statistics
ssl_model = None
pool_model = None
ser_model = None
wav_mean = None
wav_std = None

# Configuration
MODEL_PATH = "trained_models"
SSL_TYPE = utils.get_ssl_type("wavlm-large")
POOLING_TYPE = "AttentiveStatisticsPooling"
HEAD_DIM = 1024
EMOTION_MAP = ['A', 'S', 'H', 'U', 'F', 'D', 'C', 'N']
EMOTION_NAMES = {
    'A': 'Angry', 'S': 'Sad', 'H': 'Happy', 'U': 'Surprise',
    'F': 'Fear', 'D': 'Disgust', 'C': 'Contempt', 'N': 'Neutral'
}
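# EMOTION_MAP's index order must match the output order of the 8-way
# emotion head; EMOTION_NAMES expands the one-letter codes for display.
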
def load_models():
    global ssl_model, pool_model, ser_model, wav_mean, wav_std
    if ssl_model is None:
        print("Downloading and loading models from Hugging Face...")
        # Paths to files in the repo
        repo_id = "Samara369/SER_1"
        ssl_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_ssl.pt")
        pool_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_pool.pt")
        ser_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_ser.pt")
        norm_path = hf_hub_download(repo_id=repo_id, filename="trained_models/train_norm_stat.pkl")

        # Load SSL backbone, then overwrite with the fine-tuned weights
        ssl_model = AutoModel.from_pretrained("microsoft/wavlm-large")
        ssl_model.load_state_dict(torch.load(ssl_path, map_location="cpu"))
        ssl_model.eval()

        # Load pooling model
        feat_dim = ssl_model.config.hidden_size
        pool_net = getattr(net, POOLING_TYPE)
        pool_model = pool_net(feat_dim)
        pool_model.load_state_dict(torch.load(pool_path, map_location="cpu"))
        pool_model.eval()

        # Load emotion head (attentive statistics pooling concatenates
        # mean and std vectors, doubling the input feature dimension)
        dh_input_dim = feat_dim * 2 if POOLING_TYPE == "AttentiveStatisticsPooling" else feat_dim
        ser_model = net.EmotionRegression(dh_input_dim, HEAD_DIM, 1, 8, dropout=0.5)
        ser_model.load_state_dict(torch.load(ser_path, map_location="cpu"))
        ser_model.eval()

        # Load normalization stats
        wav_mean, wav_std = utils.load_norm_stat(norm_path)
        print("Models loaded from Hugging Face.")

def process_single_audio(wav_path):
    wav, _ = librosa.load(wav_path, sr=16000)
    wav = torch.tensor(wav).float()
    wav = (wav - wav_mean) / wav_std
    wav = wav.unsqueeze(0)
    mask = torch.ones_like(wav)
    return wav, mask

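# Inference pipeline: SSL backbone -> frame features -> pooled utterance
# embedding -> scores for the 8 emotion classes.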
def predict_emotion(audio_path):
    try:
        x, mask = process_single_audio(audio_path)
        with torch.no_grad():
            ssl_out = ssl_model(x, attention_mask=mask).last_hidden_state
            pooled = pool_model(ssl_out, mask)
            logits = ser_model(pooled)
            # Softmax over the 8 head outputs (treated as class logits) so
            # the scores land in [0, 1], matching the UI's 0-1 confidence slider
            probs = torch.softmax(logits, dim=-1).cpu().numpy().flatten()
        emotion_index = int(np.argmax(probs))
        predicted_emotion = EMOTION_MAP[emotion_index]
        confidence = float(probs[emotion_index])
        return predicted_emotion, confidence, probs
    except Exception as e:
        print("Error during inference:", e)
        return "Error", 0.0, None

def process_audio_file(audio_file):
    if audio_file is None:
        return "No audio file provided", 0.0, ""
    load_models()
    emotion_code, confidence, probabilities = predict_emotion(audio_file)
    if emotion_code == "Error":
        return "Error", 0.0, "❌ Error processing audio."
    emotion_name = EMOTION_NAMES.get(emotion_code, emotion_code)
    result_text = f"🎵 **Predicted Emotion:** {emotion_name} ({emotion_code})\n\n"
    result_text += f"🎯 **Confidence:** {confidence:.4f}\n\n"
    result_text += "**All Emotion Scores:**\n"
    for i, score in enumerate(probabilities):
        name = EMOTION_NAMES.get(EMOTION_MAP[i], EMOTION_MAP[i])
        result_text += f"- {name} ({EMOTION_MAP[i]}): {score:.4f}\n"
    return emotion_name, confidence, result_text

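# Handler for microphone input. gr.Audio with type="numpy" yields a
# (sample_rate, samples) tuple, so the recording is round-tripped through
# a temporary WAV file to reuse the file-based pipeline.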
def process_live_audio(audio_data):
    if audio_data is None:
        return "No audio recorded", 0.0, ""
    sample_rate, samples = audio_data
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    temp_file.close()
    sf.write(temp_file.name, samples, sample_rate)
    load_models()
    emotion_code, confidence, probabilities = predict_emotion(temp_file.name)
    os.unlink(temp_file.name)
    if emotion_code == "Error":
        return "Error", 0.0, "❌ Error processing live audio."
    emotion_name = EMOTION_NAMES.get(emotion_code, emotion_code)
    result_text = f"🎤 **Predicted Emotion:** {emotion_name} ({emotion_code})\n\n"
    result_text += f"🎯 **Confidence:** {confidence:.4f}\n\n"
    result_text += "**All Emotion Scores:**\n"
    for i, score in enumerate(probabilities):
        name = EMOTION_NAMES.get(EMOTION_MAP[i], EMOTION_MAP[i])
        result_text += f"- {name} ({EMOTION_MAP[i]}): {score:.4f}\n"
    return emotion_name, confidence, result_text

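# Two-column Blocks layout: inputs (file upload and live recording) on
# the left; prediction, confidence, and score breakdown on the right.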
def create_ser_interface():
    with gr.Blocks(title="Speech Emotion Recognition") as interface:
        gr.Markdown("## 🎙️ Speech Emotion Recognition\nUpload or record audio to detect emotion.")
        with gr.Row():
            with gr.Column():
                file_input = gr.Audio(label="Upload Audio File", type="filepath")
                file_btn = gr.Button("Analyze Uploaded Audio")
                gr.Markdown("---")
                live_audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Record Live Audio")
                live_btn = gr.Button("Analyze Live Audio")
            with gr.Column():
                emotion_output = gr.Textbox(label="Predicted Emotion", interactive=False)
                confidence_output = gr.Slider(minimum=0, maximum=1, label="Confidence", interactive=False)
                results_output = gr.Markdown(label="Results")
        file_btn.click(process_audio_file, inputs=[file_input], outputs=[emotion_output, confidence_output, results_output])
        live_btn.click(process_live_audio, inputs=[live_audio], outputs=[emotion_output, confidence_output, results_output])
    return interface

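# 0.0.0.0:7860 is the host/port a Hugging Face Space expects a Gradio
# app to listen on.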
if __name__ == "__main__":
    interface = create_ser_interface()
    interface.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)