sancho10 committed on
Commit
849911b
·
verified ·
1 Parent(s): d220c73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -61
app.py CHANGED
@@ -1,61 +1,65 @@
1
- import gradio as gr
2
- import numpy as np
3
- import tensorflow as tf
4
- import librosa
5
- import librosa.util
6
-
7
def predict_class(file_path, model, labels):
    """Classify one audio file and return its label string.

    Loads the audio at its native sample rate, computes 13 MFCCs,
    normalizes the feature matrix to a fixed (13, 100) shape, and runs
    the classifier on a single-example batch.
    """
    signal, sample_rate = librosa.load(file_path, sr=None)
    features = librosa.feature.mfcc(y=signal, sr=sample_rate, n_mfcc=13)

    # Force exactly 100 time frames (zero-pad or truncate along axis 1).
    features = librosa.util.fix_length(features, size=100, axis=1)

    # Guard the coefficient axis too, so the model always sees (13, 100).
    if features.shape[0] != 13:
        features = librosa.util.fix_length(features, size=13, axis=0)

    # The model expects a 4-D batch: (batch, height, width, channels).
    batch = features[np.newaxis, ..., np.newaxis]  # (1, 13, 100, 1)

    scores = model.predict(batch)
    return labels[np.argmax(scores)]
27
-
28
# Load the pre-trained Keras model from the app's working directory.
model = tf.keras.models.load_model("voice_classification_modelm.h5")

# Class labels, ordered to match the model's output indices
# (presumably the training folder names — TODO confirm the ordering
# matches how the model was trained).
labels = [
    "all_vowels_healthy",
    "allvowels_functional",
    "allvowels_laryngitis",
    "allvowels_lukoplakia",
    "allvowels_psychogenic",
    "allvowels_rlnp",
    "allvowels_sd"
]
41
-
42
def classify_audio(audio_file):
    """Gradio callback: classify the uploaded file, reporting errors as text."""
    try:
        label = predict_class(audio_file, model, labels)
    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return f"Error: {str(e)}"
    return f"Predicted Class: {label}"
49
-
50
# Build the Gradio UI.
# NOTE(review): `source="upload"` is the Gradio 3.x keyword; Gradio 4+
# renamed it to `sources=["upload"]` — confirm the pinned gradio version,
# otherwise startup raises TypeError.
interface = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(source="upload", type="filepath", label="Upload an Audio File"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="Voice Classification",
    description="Upload an audio file to classify its voice type.",
    # NOTE(review): the example file must exist next to app.py or Gradio
    # errors at load time.
    examples=["example_audio.wav"]  # Replace with paths to sample audio files
)

# Launch the app
interface.launch()
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import librosa
5
+ import librosa.util
6
+ import pickle
7
+ from sklearn.preprocessing import LabelEncoder
8
+
9
def extract_features(file_path):
    """Extract a fixed-size MFCC feature matrix from an audio file.

    Parameters
    ----------
    file_path : str
        Path to the audio file to load.

    Returns
    -------
    dict
        {"mfcc": ndarray of shape (13, 100)} — 13 MFCC coefficients over
        exactly 100 frames (zero-padded or truncated).

    Raises
    ------
    ValueError
        If loading or feature extraction fails. The original exception is
        chained as ``__cause__`` so the real traceback is preserved.
    """
    try:
        # Resample to 8 kHz so inference features match the training pipeline.
        y, sr = librosa.load(file_path, sr=8000)
        mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)

        # Pad or truncate the time axis to exactly 100 frames.
        mfcc = librosa.util.fix_length(mfcc, size=100, axis=1)

        # Defensive: force the coefficient axis to 13 as well.
        if mfcc.shape[0] != 13:
            mfcc = librosa.util.fix_length(mfcc, size=13, axis=0)

        return {"mfcc": mfcc}
    except Exception as e:
        # Bug fix: chain the cause (`from e`) instead of swallowing it —
        # otherwise the underlying librosa/IO error is hidden from logs.
        raise ValueError(f"Error in feature extraction: {str(e)}") from e
26
+
27
def predict_class(file_path, model, label_encoder):
    """Run the classifier on one audio file and format the result.

    Returns a human-readable string: either the decoded class name or an
    error message — exceptions are caught, never propagated to the caller.
    """
    try:
        mfcc = extract_features(file_path)["mfcc"]

        # Reshape (13, 100) -> (1, 13, 100, 1): batch + channel dimensions.
        batch = mfcc[np.newaxis, ..., np.newaxis]

        scores = model.predict(batch)
        best_index = np.argmax(scores)
        # Map the winning index back to its original class name.
        decoded = label_encoder.inverse_transform([best_index])
        return f"Predicted Class: {decoded[0]}"
    except Exception as e:
        return f"Error in prediction: {str(e)}"
42
+
43
# Load the pre-trained Keras classifier.
model = tf.keras.models.load_model("voice_classification_modelm.h5")

# Restore the fitted LabelEncoder used at training time, so predicted
# indices can be mapped back to class names.
# NOTE(review): pickle.load is unsafe on untrusted input — acceptable here
# only because label_encoder.pkl ships with this app; do not load
# user-supplied pickles.
with open("label_encoder.pkl", "rb") as f:
    label_encoder = pickle.load(f)
49
+
50
def classify_audio(audio_file):
    """Gradio entry point: classify the uploaded audio file path."""
    # predict_class handles its own errors and always returns a string.
    result = predict_class(audio_file, model, label_encoder)
    return result
53
+
54
# Build the Gradio UI.
interface = gr.Interface(
    fn=classify_audio,
    # Bug fix: Gradio 4+ renamed Audio's `source="upload"` to
    # `sources=["upload"]`; the old keyword raises TypeError at startup
    # on current Gradio (the Hugging Face default SDK).
    inputs=gr.Audio(sources=["upload"], type="filepath", label="Upload an Audio File"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="Voice Disorder Classification",
    description="Upload an audio file to classify its voice type (e.g., healthy or various disorder types).",
    # NOTE(review): this example file must exist in the repo, or Gradio
    # fails while building the examples component.
    examples=["example_audio.wav"],  # Replace with paths to example audio files
)

# Launch the Gradio app
interface.launch()