import os
import tempfile

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel

import net
import utils

# Globals for lazily loaded models and normalization stats (see load_models)
ssl_model = None
pool_model = None
ser_model = None
wav_mean = None
wav_std = None

# Configuration
MODEL_PATH = "trained_models"
SSL_TYPE = utils.get_ssl_type("wavlm-large")
POOLING_TYPE = "AttentiveStatisticsPooling"
HEAD_DIM = 1024

EMOTION_MAP = ['A', 'S', 'H', 'U', 'F', 'D', 'C', 'N']
EMOTION_NAMES = {
    'A': 'Angry',
    'S': 'Sad',
    'H': 'Happy',
    'U': 'Surprise',
    'F': 'Fear',
    'D': 'Disgust',
    'C': 'Contempt',
    'N': 'Neutral',
}


def load_models():
    """Download the fine-tuned checkpoints from the Hugging Face Hub and load them once."""
    global ssl_model, pool_model, ser_model, wav_mean, wav_std
    if ssl_model is not None:
        return

    print("Downloading and loading models from Hugging Face...")
    repo_id = "Samara369/SER_1"
    ssl_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_ssl.pt")
    pool_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_pool.pt")
    ser_path = hf_hub_download(repo_id=repo_id, filename="trained_models/final_ser.pt")
    norm_path = hf_hub_download(repo_id=repo_id, filename="trained_models/train_norm_stat.pkl")

    # SSL backbone: WavLM-large with fine-tuned weights
    ssl_model = AutoModel.from_pretrained("microsoft/wavlm-large")
    ssl_model.load_state_dict(torch.load(ssl_path, map_location="cpu"))
    ssl_model.eval()

    # Pooling layer over the frame-level SSL features
    feat_dim = ssl_model.config.hidden_size
    pool_net = getattr(net, POOLING_TYPE)
    pool_model = pool_net(feat_dim)
    pool_model.load_state_dict(torch.load(pool_path, map_location="cpu"))
    pool_model.eval()

    # Emotion head: attentive statistics pooling concatenates mean and std,
    # doubling the input dimension; the head emits one score per emotion class
    dh_input_dim = feat_dim * 2 if POOLING_TYPE == "AttentiveStatisticsPooling" else feat_dim
    ser_model = net.EmotionRegression(dh_input_dim, HEAD_DIM, 1, len(EMOTION_MAP), dropout=0.5)
    ser_model.load_state_dict(torch.load(ser_path, map_location="cpu"))
    ser_model.eval()

    # Waveform normalization statistics computed on the training set
    wav_mean, wav_std = utils.load_norm_stat(norm_path)
    print("Models loaded from Hugging Face.")


def process_single_audio(wav_path):
    """Load a file at 16 kHz, normalize with training stats, return (batch, mask)."""
    wav, _ = librosa.load(wav_path, sr=16000)
    wav = torch.tensor(wav).float()
    wav = (wav - wav_mean) / wav_std
    wav = wav.unsqueeze(0)  # add batch dimension: (1, num_samples)
    mask = torch.ones_like(wav)
    return wav, mask


def predict_emotion(audio_path):
    """Run the SSL -> pooling -> head pipeline on a single audio file."""
    try:
        x, mask = process_single_audio(audio_path)
        with torch.no_grad():
            ssl_out = ssl_model(x, attention_mask=mask).last_hidden_state
            pooled = pool_model(ssl_out, mask)
            logits = ser_model(pooled).cpu().numpy().flatten()
        # Softmax the raw scores so the reported confidence lies in [0, 1]
        # and matches the range of the UI's confidence slider
        probabilities = np.exp(logits - np.max(logits))
        probabilities /= probabilities.sum()
        emotion_index = int(np.argmax(probabilities))
        predicted_emotion = EMOTION_MAP[emotion_index]
        confidence = float(probabilities[emotion_index])
        return predicted_emotion, confidence, probabilities
    except Exception as e:
        print("Error during inference:", e)
        return "Error", 0.0, None


def format_results(icon, emotion_code, confidence, probabilities):
    """Build the markdown scorecard shared by both handlers."""
    emotion_name = EMOTION_NAMES.get(emotion_code, emotion_code)
    result_text = f"{icon} **Predicted Emotion:** {emotion_name} ({emotion_code})\n\n"
    result_text += f"🎯 **Confidence:** {confidence:.4f}\n\n"
    result_text += "**All Emotion Scores:**\n"
    for code, score in zip(EMOTION_MAP, probabilities):
        result_text += f"- {EMOTION_NAMES.get(code, code)} ({code}): {score:.4f}\n"
    return emotion_name, result_text


def process_audio_file(audio_file):
    """Handler for the file-upload path."""
    if audio_file is None:
        return "No audio file provided", 0.0, ""
    load_models()
    emotion_code, confidence, probabilities = predict_emotion(audio_file)
    if emotion_code == "Error":
        return "Error", 0.0, "❌ Error processing audio."
    emotion_name, result_text = format_results("🎵", emotion_code, confidence, probabilities)
    return emotion_name, confidence, result_text
def process_live_audio(audio_data):
    """Handler for the microphone path; gr.Audio(type="numpy") yields (sample_rate, samples)."""
    if audio_data is None:
        return "No audio recorded", 0.0, ""
    sample_rate, samples = audio_data
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    temp_file.close()
    sf.write(temp_file.name, samples, sample_rate)
    load_models()
    try:
        emotion_code, confidence, probabilities = predict_emotion(temp_file.name)
    finally:
        os.unlink(temp_file.name)  # clean up the temp WAV even on failure
    if emotion_code == "Error":
        return "Error", 0.0, "❌ Error processing live audio."
    emotion_name, result_text = format_results("🎤", emotion_code, confidence, probabilities)
    return emotion_name, confidence, result_text


def create_ser_interface():
    """Build the Gradio Blocks UI with upload and live-recording inputs."""
    with gr.Blocks(title="Speech Emotion Recognition") as interface:
        gr.Markdown("## 🎙️ Speech Emotion Recognition\nUpload or record audio to detect emotion.")
        with gr.Row():
            with gr.Column():
                file_input = gr.Audio(label="Upload Audio File", type="filepath")
                file_btn = gr.Button("Analyze Uploaded Audio")
                gr.Markdown("---")
                live_audio = gr.Audio(sources=["microphone", "upload"], type="numpy",
                                      label="Record Live Audio")
                live_btn = gr.Button("Analyze Live Audio")
            with gr.Column():
                emotion_output = gr.Textbox(label="Predicted Emotion", interactive=False)
                confidence_output = gr.Slider(minimum=0, maximum=1, label="Confidence",
                                              interactive=False)
                results_output = gr.Markdown(label="Results")
        file_btn.click(process_audio_file, inputs=[file_input],
                       outputs=[emotion_output, confidence_output, results_output])
        live_btn.click(process_live_audio, inputs=[live_audio],
                       outputs=[emotion_output, confidence_output, results_output])
    return interface


if __name__ == "__main__":
    interface = create_ser_interface()
    interface.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
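
# A minimal REPL smoke test (a sketch, not invoked by the app): it assumes this
# module is saved as app.py and that sample.wav is a local recording -- both
# names are hypothetical placeholders.
#
#   >>> from app import load_models, process_audio_file
#   >>> load_models()                          # one-time download from the Hub
#   >>> name, conf, report = process_audio_file("sample.wav")
#   >>> print(name, conf)                      # e.g. top class and its softmax score
#   >>> print(report)                          # markdown scores for all 8 classes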