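"""Gradio app that takes a recorded or uploaded song clip and predicts which dances fit the music using the ResidualDancer model."""
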
import json
import os
from functools import cache
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import torch

from models.residual import ResidualDancer
from preprocessing.preprocess import AudioPipeline
@cache
def get_model(device) -> tuple[ResidualDancer, np.ndarray]:
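    """Load the ResidualDancer model onto the given device and return it with its sorted class labels."""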
    model_path = "models/weights/ResidualDancer"
    weights = os.path.join(model_path, "dancer_net.pt")
    config_path = os.path.join(model_path, "config.json")
    with open(config_path) as f:
        config = json.load(f)
    labels = np.array(sorted(config["classes"]))
    model = ResidualDancer(n_classes=len(labels))
    # map_location lets weights saved on one device (e.g. GPU) load onto the active device.
    model.load_state_dict(torch.load(weights, map_location=device))
    model = model.to(device).eval()
    return model, labels
@cache
def get_pipeline(sample_rate: int) -> AudioPipeline:
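    """Build and cache the audio preprocessing pipeline for the given input sample rate."""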
    return AudioPipeline(input_freq=sample_rate)
@cache
def get_dance_map() -> dict:
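    """Map dance ids to human-readable dance names using data/dance_mapping.csv."""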
    df = pd.read_csv("data/dance_mapping.csv")
    return df.set_index("id").to_dict()["name"]
def predict(audio: tuple[int, np.ndarray]) -> dict[str, float] | str:
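    """Classify a (sample_rate, waveform) recording and return the dances whose score exceeds the threshold."""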
    sample_rate, waveform = audio
    expected_duration = 6  # seconds of audio used for a prediction
    threshold = 0.5
    sample_len = sample_rate * expected_duration
    # Use Apple's MPS backend when available, otherwise fall back to CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    audio_pipeline = get_pipeline(sample_rate)
    model, labels = get_model(device)
    if sample_len > len(waveform):
        raise gr.Error("You must record for at least 6 seconds")
    # Collapse multi-channel audio to mono: (samples, channels) -> (1, samples).
    if len(waveform.shape) > 1 and waveform.shape[1] > 1:
        waveform = waveform.transpose(1, 0)
        waveform = waveform.mean(axis=0, keepdims=True)
    else:
        waveform = np.expand_dims(waveform, 0)
    waveform = waveform[:, :sample_len]
    # Min-max scale the samples into the [-1, 1] range before building the spectrogram.
    waveform = (waveform - waveform.min()) / (waveform.max() - waveform.min()) * 2 - 1
    waveform = waveform.astype("float32")
    waveform = torch.from_numpy(waveform)
    spectrogram = audio_pipeline(waveform)
    spectrogram = spectrogram.unsqueeze(0).to(device)
    with torch.no_grad():
        results = model(spectrogram)
    dance_mapping = get_dance_map()
    results = results.squeeze(0).detach().cpu().numpy()
    # Keep only the dances whose score clears the threshold.
    result_mask = results > threshold
    probs = results[result_mask]
    dances = labels[result_mask]
    if not len(dances):
        return "Couldn't find a dance."
    return {dance_mapping[dance_id]: float(prob) for dance_id, prob in zip(dances, probs)}
def demo():
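    """Build the Gradio Blocks app with record/upload tabs, example clips, and the prediction output."""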
    title = "Dance Classifier"
    description = "Record 6 seconds of a song and find out what dance fits the music."
    with gr.Blocks() as app:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)
        with gr.Tab("Record Song"):
            mic_audio = gr.Audio(source="microphone", label="Song Recording")
            mic_submit = gr.Button("Predict")
        with gr.Tab("Upload Song"):
            audio_file = gr.Audio(label="Song Audio File")
            audio_file_submit = gr.Button("Predict")
        # Collect example clips from assets/song-samples, skipping hidden files.
        song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
        example_audio = [str(song) for song in song_samples.iterdir() if song.name[0] != "."]
        labels = gr.Label(label="Dances")
        gr.Markdown("## Examples")
        gr.Examples(
            examples=example_audio,
            inputs=audio_file,
            outputs=labels,
            fn=predict,
        )
        audio_file_submit.click(fn=predict, inputs=audio_file, outputs=labels)
        mic_submit.click(fn=predict, inputs=mic_audio, outputs=labels)
    return app
if __name__ == "__main__":
    demo().launch()