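"""Gradio app for speech emotion recognition (SER).

Pipeline: a WavLM-Large SSL backbone extracts frame features, an attentive
statistics pooling layer summarizes each clip, and an 8-way emotion head
scores the result. Fine-tuned weights are downloaded from the
Samara369/SER_1 repo on the Hugging Face Hub on first use.
"""
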
import os
import tempfile

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel

import net
import utils

# Global variables to store loaded models
ssl_model = None
pool_model = None
ser_model = None
wav_mean = None
wav_std = None
# Configuration
MODEL_PATH = "trained_models"
SSL_TYPE = utils.get_ssl_type("wavlm-large")
POOLING_TYPE = "AttentiveStatisticsPooling"
HEAD_DIM = 1024
EMOTION_MAP = ['A', 'S', 'H', 'U', 'F', 'D', 'C', 'N']
EMOTION_NAMES = {
    'A': 'Angry', 'S': 'Sad', 'H': 'Happy', 'U': 'Surprise',
    'F': 'Fear', 'D': 'Disgust', 'C': 'Contempt', 'N': 'Neutral'
}


def load_models():
    global ssl_model, pool_model, ser_model, wav_mean, wav_std
    if ssl_model is None:
        print("Downloading and loading models from Hugging Face...")
        # Paths to files in the repo
        repo_id = "Samara369/SER_1"
        ssl_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_ssl.pt")
        pool_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_pool.pt")
        ser_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_ser.pt")
        norm_path = hf_hub_download(repo_id=repo_id, filename="trained_models/train_norm_stat.pkl")

        # Load SSL model
        ssl_model = AutoModel.from_pretrained("microsoft/wavlm-large")
        ssl_model.load_state_dict(torch.load(ssl_path, map_location="cpu"))
        ssl_model.eval()

        # Load pooling model
        feat_dim = ssl_model.config.hidden_size
        pool_net = getattr(net, POOLING_TYPE)
        pool_model = pool_net(feat_dim)
        pool_model.load_state_dict(torch.load(pool_path, map_location="cpu"))
        pool_model.eval()
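
        # AttentiveStatisticsPooling emits the concatenation of an attention-
        # weighted mean and standard deviation over time (2 * feat_dim values),
        # which sets the input size of the head below.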
        # Load prediction head (net.EmotionRegression, used here as an 8-way classifier)
        dh_input_dim = feat_dim * 2 if POOLING_TYPE == "AttentiveStatisticsPooling" else feat_dim
        ser_model = net.EmotionRegression(dh_input_dim, HEAD_DIM, 1, 8, dropout=0.5)
        ser_model.load_state_dict(torch.load(ser_path, map_location="cpu"))
        ser_model.eval()

        # Load normalization stats
        wav_mean, wav_std = utils.load_norm_stat(norm_path)
        print("Models loaded from Hugging Face.")


def process_single_audio(wav_path):
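    # Resample to 16 kHz (the rate WavLM-Large expects) and z-normalize the
    # raw waveform with the training-set mean/std loaded in load_models().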
    wav, _ = librosa.load(wav_path, sr=16000)
    wav = torch.tensor(wav).float()
    wav = (wav - wav_mean) / wav_std
    wav = wav.unsqueeze(0)
    mask = torch.ones_like(wav)
    return wav, mask


def predict_emotion(audio_path):
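    # Forward pass: WavLM frame features -> pooled clip embedding -> 8 emotion
    # scores. Scores are softmaxed so "confidence" is a probability in [0, 1],
    # matching the 0-1 confidence slider in the UI.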
    try:
        x, mask = process_single_audio(audio_path)
        with torch.no_grad():
            ssl_out = ssl_model(x, attention_mask=mask).last_hidden_state
            pooled = pool_model(ssl_out, mask)
            logits = ser_model(pooled)
            probs = torch.softmax(logits, dim=-1).cpu().numpy().flatten()
        emotion_index = int(np.argmax(probs))
        predicted_emotion = EMOTION_MAP[emotion_index]
        confidence = float(probs[emotion_index])
        return predicted_emotion, confidence, probs
    except Exception as e:
        print("Error during inference:", e)
        return "Error", 0.0, None


def process_audio_file(audio_file):
    if audio_file is None:
        return "No audio file provided", 0.0, ""
    load_models()
    emotion_code, confidence, probabilities = predict_emotion(audio_file)
    if emotion_code == "Error":
        return "Error", 0.0, "❌ Error processing audio."
    emotion_name = EMOTION_NAMES.get(emotion_code, emotion_code)
    result_text = f"🎡 **Predicted Emotion:** {emotion_name} ({emotion_code})\n\n"
    result_text += f"🎯 **Confidence:** {confidence:.4f}\n\n"
    result_text += "**All Emotion Scores:**\n"
    for i, score in enumerate(probabilities):
        name = EMOTION_NAMES.get(EMOTION_MAP[i], EMOTION_MAP[i])
        result_text += f"- {name} ({EMOTION_MAP[i]}): {score:.4f}\n"
    return emotion_name, confidence, result_text


def process_live_audio(audio_data):
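    # Gradio's numpy audio components return a (sample_rate, samples) tuple;
    # write it to a temporary WAV so the file-based pipeline can handle it.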
    if audio_data is None:
        return "No audio recorded", 0.0, ""
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    temp_file.close()
    sf.write(temp_file.name, audio_data[1], audio_data[0])
    load_models()
    emotion_code, confidence, probabilities = predict_emotion(temp_file.name)
    os.unlink(temp_file.name)
    if emotion_code == "Error":
        return "Error", 0.0, "❌ Error processing live audio."
    emotion_name = EMOTION_NAMES.get(emotion_code, emotion_code)
    result_text = f"🎀 **Predicted Emotion:** {emotion_name} ({emotion_code})\n\n"
    result_text += f"🎯 **Confidence:** {confidence:.4f}\n\n"
    result_text += "**All Emotion Scores:**\n"
    for i, score in enumerate(probabilities):
        name = EMOTION_NAMES.get(EMOTION_MAP[i], EMOTION_MAP[i])
        result_text += f"- {name} ({EMOTION_MAP[i]}): {score:.4f}\n"
    return emotion_name, confidence, result_text


def create_ser_interface():
    with gr.Blocks(title="Speech Emotion Recognition") as interface:
        gr.Markdown("## 🎙️ Speech Emotion Recognition\nUpload or record audio to detect emotion.")
        with gr.Row():
            with gr.Column():
                file_input = gr.Audio(label="Upload Audio File", type="filepath")
                file_btn = gr.Button("Analyze Uploaded Audio")
                gr.Markdown("---")
                live_audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Record Live Audio")
                live_btn = gr.Button("Analyze Live Audio")
            with gr.Column():
                emotion_output = gr.Textbox(label="Predicted Emotion", interactive=False)
                confidence_output = gr.Slider(minimum=0, maximum=1, label="Confidence", interactive=False)
                results_output = gr.Markdown(label="Results")
        file_btn.click(process_audio_file, inputs=[file_input], outputs=[emotion_output, confidence_output, results_output])
        live_btn.click(process_live_audio, inputs=[live_audio], outputs=[emotion_output, confidence_output, results_output])
    return interface
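

# Standard Hugging Face Spaces entry point: Gradio apps on Spaces are expected
# to listen on 0.0.0.0:7860.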
if __name__ == "__main__":
    interface = create_ser_interface()
    interface.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)