File size: 3,544 Bytes
c914273
 
 
 
 
 
7b37b0e
c914273
 
 
 
 
7b37b0e
 
c914273
7b37b0e
 
c914273
 
 
 
 
 
 
7b37b0e
c914273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from pathlib import Path
import gradio as gr
import numpy as np
import torch
from preprocessing.preprocess import AudioPipeline
from preprocessing.preprocess import AudioPipeline
from models.residual import ResidualDancer
import os
import json
from functools import cache
import pandas as pd



@cache
def get_model(device) -> tuple[ResidualDancer, np.ndarray]:
    """Load the ResidualDancer network and its class labels for ``device``.

    The result is memoized per ``device`` value, so the weights and config
    are read from disk only once for each device.

    Args:
        device: Target device handed to ``torch.load`` and ``Module.to``
            (e.g. ``"cpu"``, ``"cuda"`` or ``"mps"``). It is used as the
            cache key and therefore must be hashable.

    Returns:
        A ``(model, labels)`` tuple: the model moved to ``device`` and put in
        eval mode, and the sorted entries of ``config["classes"]`` as a
        NumPy array.
    """
    model_path = "models/weights/ResidualDancer"
    weights = os.path.join(model_path, "dancer_net.pt")
    config_path = os.path.join(model_path, "config.json")

    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    labels = np.array(sorted(config["classes"]))

    model = ResidualDancer(n_classes=len(labels))
    # map_location remaps the stored tensors onto the requested device, so a
    # checkpoint saved on a different accelerator still loads on this host.
    state_dict = torch.load(weights, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device).eval()
    return model, labels

@cache
def get_pipeline(sample_rate: int) -> AudioPipeline:
    """Return the ``AudioPipeline`` for ``sample_rate``, building it on first use.

    Results are memoized, so repeated calls with the same sample rate share
    a single pipeline instance.
    """
    pipeline = AudioPipeline(input_freq=sample_rate)
    return pipeline

@cache
def get_dance_map() -> dict:
    """Load the dance ``id`` -> ``name`` lookup from ``data/dance_mapping.csv``.

    The CSV is read on the first call only; later calls return the cached
    dictionary.
    """
    mapping_table = pd.read_csv("data/dance_mapping.csv")
    names_by_id = mapping_table.set_index("id")["name"]
    return names_by_id.to_dict()


def _select_device() -> str:
    """Pick the best available torch device: ``mps``, then ``cuda``, then ``cpu``."""
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def predict(audio: tuple[int, np.ndarray]) -> "dict[str, float] | str":
    """Classify the first six seconds of a recording into matching dances.

    Args:
        audio: ``(sample_rate, waveform)`` pair as supplied by the Gradio
            audio inputs wired to this function. The waveform is indexed by
            sample along its first axis; an optional second axis holds the
            channels, which are averaged down to mono.

    Returns:
        A mapping from dance name to model output for every class whose
        output exceeds the threshold, or the message
        ``"Couldn't find a dance."`` when no class does.

    Raises:
        gr.Error: If no audio was supplied or it is shorter than six seconds.
    """
    expected_duration = 6  # seconds of audio fed to the model
    threshold = 0.5  # minimum model output for a dance to be reported

    # Gradio passes None when the button is pressed without any audio.
    if audio is None:
        raise gr.Error("Please record or upload a song first.")

    sample_rate, waveform = audio
    sample_len = sample_rate * expected_duration
    if sample_len > len(waveform):
        raise gr.Error("You must record for at least 6 seconds")

    device = _select_device()
    audio_pipeline = get_pipeline(sample_rate)
    model, labels = get_model(device)

    # Convert to float before any arithmetic: integer PCM samples would wrap
    # around in the min/max normalisation below.
    waveform = np.asarray(waveform[:sample_len], dtype=np.float64)
    if waveform.ndim > 1:
        # Down-mix (samples, channels) to mono; this also covers the
        # single-channel (samples, 1) layout.
        waveform = waveform.mean(axis=1)
    waveform = waveform[np.newaxis, :]

    # Rescale to [-1, 1]; a constant signal has no range, so map it to
    # silence instead of dividing by zero.
    lowest = waveform.min()
    span = waveform.max() - lowest
    if span > 0:
        waveform = (waveform - lowest) / span * 2 - 1
    else:
        waveform = np.zeros_like(waveform)

    waveform = torch.from_numpy(waveform.astype("float32"))
    spectrogram = audio_pipeline(waveform)
    spectrogram = spectrogram.unsqueeze(0).to(device)

    with torch.no_grad():
        results = model(spectrogram)
    dance_mapping = get_dance_map()
    results = results.squeeze(0).detach().cpu().numpy()
    result_mask = results > threshold
    probs = results[result_mask]
    dances = labels[result_mask]

    if dances.size == 0:
        return "Couldn't find a dance."
    return {
        dance_mapping[dance_id]: float(prob)
        for dance_id, prob in zip(dances, probs)
    }


def demo():
    """Assemble the Gradio Blocks UI for the dance classifier.

    The app offers a microphone tab and an upload tab, each with its own
    "Predict" button, a shared label output, and a list of bundled example
    songs that feed the upload input.

    Returns:
        The constructed ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    title = "Dance Classifier"
    description = "Record 6 seconds of a song and find out what dance fits the music."

    # Bundled sample songs live next to this file; skip hidden dot-files.
    samples_dir = Path(os.path.dirname(__file__), "assets", "song-samples")
    sample_paths = []
    for entry in samples_dir.iterdir():
        if not entry.name.startswith("."):
            sample_paths.append(str(entry))

    with gr.Blocks() as app:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)

        with gr.Tab("Record Song"):
            mic_audio = gr.Audio(source="microphone", label="Song Recording")
            mic_submit = gr.Button("Predict")

        with gr.Tab("Upload Song"):
            audio_file = gr.Audio(label="Song Audio File")
            audio_file_submit = gr.Button("Predict")

        labels = gr.Label(label="Dances")

        gr.Markdown("## Examples")
        gr.Examples(
            examples=sample_paths,
            inputs=audio_file,
            outputs=labels,
            fn=predict,
        )

        # Both buttons run the same classifier and write to the shared label.
        audio_file_submit.click(fn=predict, inputs=audio_file, outputs=labels)
        mic_submit.click(fn=predict, inputs=mic_audio, outputs=labels)

    return app


if __name__ == "__main__":
    # Build the UI and start the Gradio server when executed as a script.
    app = demo()
    app.launch()