RITISHREE / app.py
sancho10's picture
Update app.py
849911b verified
raw
history blame
2.19 kB
import gradio as gr
import numpy as np
import tensorflow as tf
import librosa
import librosa.util
import pickle
from sklearn.preprocessing import LabelEncoder
# Feature extraction function
def extract_features(file_path, sr=8000, n_mfcc=13, n_frames=100):
    """Load an audio file and return a fixed-size MFCC matrix.

    The audio is loaded with librosa at ``sr`` Hz, converted to ``n_mfcc``
    MFCC coefficients per frame, and then padded or truncated along the
    time axis so the result always has shape ``(n_mfcc, n_frames)``.

    Args:
        file_path: Path to the audio file to analyse.
        sr: Sampling rate in Hz to resample to while loading.
        n_mfcc: Number of MFCC coefficients (rows of the result).
        n_frames: Number of time frames (columns of the result).

    Returns:
        A dict with the single key ``"mfcc"`` holding the MFCC array.

    Raises:
        ValueError: If no path is given, or if loading / feature
            extraction fails for any reason (the original error is
            chained as ``__cause__``).
    """
    # Fail fast with a clear message instead of an opaque loader error.
    if not file_path:
        raise ValueError("Error in feature extraction: no audio file provided")

    try:
        # Load the audio file, resampling to the requested rate
        y, sr = librosa.load(file_path, sr=sr)
        mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=n_mfcc)
        # Pad or truncate to n_frames along the time axis (axis 1)
        mfcc = librosa.util.fix_length(mfcc, size=n_frames, axis=1)
        # Defensive: force the coefficient axis to n_mfcc rows as well
        if mfcc.shape[0] != n_mfcc:
            mfcc = librosa.util.fix_length(mfcc, size=n_mfcc, axis=0)
        return {"mfcc": mfcc}
    except Exception as e:
        # Keep the ValueError contract but preserve the root cause.
        raise ValueError(f"Error in feature extraction: {e}") from e
# Prediction function
def predict_class(file_path, model, label_encoder):
    """Classify one audio file and return a human-readable result string.

    Args:
        file_path: Path to the audio file to classify.
        model: Object exposing ``predict(batch)``; its output is reduced
            with ``np.argmax`` to a class index.
        label_encoder: Object exposing ``inverse_transform`` that maps a
            list of class indices back to labels.

    Returns:
        ``"Predicted Class: <label>"`` on success, otherwise a string
        starting with ``"Error in prediction: "``. Failures are reported
        in the returned text rather than raised.
    """
    # Nothing uploaded: report it directly instead of surfacing a loader error.
    if not file_path:
        return "Error in prediction: no audio file provided"

    try:
        mfcc = extract_features(file_path)["mfcc"]
        # Add batch and channel dimensions for model compatibility,
        # e.g. a (13, 100) matrix becomes (1, 13, 100, 1).
        batch = mfcc[np.newaxis, ..., np.newaxis]
        # Make prediction
        prediction = model.predict(batch)
        predicted_class = label_encoder.inverse_transform([np.argmax(prediction)])
        return f"Predicted Class: {predicted_class[0]}"
    except Exception as e:
        # UI boundary: turn any failure into a message the user can read.
        return f"Error in prediction: {e}"
# Artifact file names, resolved relative to the current working directory
MODEL_PATH = "voice_classification_modelm.h5"
LABEL_ENCODER_PATH = "label_encoder.pkl"

# Load the pre-trained model once at import time so every request reuses it
model = tf.keras.models.load_model(MODEL_PATH)

# Load the label encoder used to turn predicted class indices into labels.
# SECURITY: pickle.load can execute arbitrary code embedded in the file;
# only ship a label_encoder.pkl that comes from a trusted source.
with open(LABEL_ENCODER_PATH, "rb") as f:
    label_encoder = pickle.load(f)
# Define the Gradio function
def classify_audio(audio_file):
    """Gradio callback: classify the given audio file.

    Delegates to ``predict_class`` with the module-level ``model`` and
    ``label_encoder`` and returns its result string unchanged.
    """
    result = predict_class(audio_file, model, label_encoder)
    return result
# Create the Gradio interface
_AUDIO_LABEL = "Upload an Audio File"
try:
    # Original (Gradio 3.x) spelling of an upload-only audio input
    audio_input = gr.Audio(source="upload", type="filepath", label=_AUDIO_LABEL)
except TypeError:
    # Gradio 4+ replaced ``source`` with the list-valued ``sources`` keyword
    audio_input = gr.Audio(sources=["upload"], type="filepath", label=_AUDIO_LABEL)

interface = gr.Interface(
    fn=classify_audio,
    inputs=audio_input,
    outputs=gr.Textbox(label="Predicted Class"),
    title="Voice Disorder Classification",
    description="Upload an audio file to classify its voice type (e.g., healthy or various disorder types).",
    examples=["example_audio.wav"],  # Replace with paths to example audio files
)

# Launch the Gradio app
interface.launch()