import os
from typing import Tuple

import gradio as gr
import torch as T
import torch.nn as nn
import torchaudio

from utils import print_colored
from tokenizer import make_tokenizer
from model import get_hertz_dev_config
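
# Gradio demo: prompt the hertz-dev audio model with a short clip and sample
# several continuations from it.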

# Model and tokenizer are loaded once and shared across requests
global_generator = None
global_tokenizer = None
default_audio_path = "sample.wav"

def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
    """Initialize the model and tokenizer, reusing them if already loaded."""
    global global_generator, global_tokenizer
    if global_generator is not None and global_tokenizer is not None:
        return global_generator, global_tokenizer
    device = 'cuda' if T.cuda.is_available() else 'cpu'
    if device == 'cuda':
        T.cuda.set_device(0)
    print_colored("Initializing model and tokenizer...", "blue")
    global_tokenizer = make_tokenizer(device)
    model_config = get_hertz_dev_config(is_split=False, use_pure_audio_ablation=use_pure_audio_ablation)
    global_generator = model_config()
    # The model runs in bfloat16 for inference
    global_generator = global_generator.eval().to(T.bfloat16).to(device)
    print_colored("Model initialization complete!", "green")
    return global_generator, global_tokenizer
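
# init_model() is cheap to call repeatedly; only the first call pays the load
# cost. make_tokenizer and get_hertz_dev_config are assumed to come from the
# hertz-dev codebase and to load checkpoint weights on first use.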

def process_audio(audio_path: str) -> T.Tensor:
    """Load an audio file and preprocess it to mono 16 kHz."""
    audio_tensor, sr = torchaudio.load(audio_path)
    if audio_tensor.shape[0] > 1:
        # Downmix multi-channel audio to mono
        audio_tensor = audio_tensor.mean(dim=0, keepdim=True)
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        audio_tensor = resampler(audio_tensor)
    max_samples = 16000 * 60 * 5  # cap the input at 5 minutes
    if audio_tensor.shape[1] > max_samples:
        audio_tensor = audio_tensor[:, :max_samples]
    return audio_tensor.unsqueeze(0)  # add a batch dimension
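
# Example: process_audio("sample.wav") returns a float tensor of shape
# (1, 1, num_samples) at 16 kHz, ready for the tokenizer's latent_from_data.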

def generate_completion(
    audio_file,
    prompt_len_seconds: float = 3.0,
    num_completions: int = 5,
    generation_seconds: float = 20.0,
    token_temp: float = 0.8,
    categorical_temp: float = 0.5,
    gaussian_temp: float = 0.1,
    progress=gr.Progress(track_tqdm=True)
) -> list:
    """Generate audio completions from the input audio."""
    device = 'cuda' if T.cuda.is_available() else 'cpu'
    # Reuse the shared model and tokenizer, initializing them if necessary
    generator, audio_tokenizer = init_model()
    progress(0, desc="Processing input audio...")
    prompt_audio = process_audio(audio_file)
    # The tokenizer produces 8 latent frames per second of audio
    prompt_len = int(prompt_len_seconds * 8)
    progress(0.2, desc="Encoding prompt...")
    with T.autocast(device_type=device, dtype=T.bfloat16):
        encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))
    completions = []
    for i in range(num_completions):
        progress((i + 1) / num_completions, desc=f"Generating completion {i+1}/{num_completions}")
        encoded_prompt = encoded_prompt_audio[:, :prompt_len]
        with T.autocast(device_type=device, dtype=T.bfloat16):
            completed_audio_batch = generator.completion(
                encoded_prompt,
                temps=(token_temp, (categorical_temp, gaussian_temp)),
                use_cache=True,
                gen_len=int(generation_seconds * 8)
            )
            decoded_completion = audio_tokenizer.data_from_latent(completed_audio_batch.bfloat16())
        # Convert to a normalized float tensor of shape (channels, samples)
        audio_tensor = decoded_completion.cpu().squeeze()
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        audio_tensor = audio_tensor.float()
        if audio_tensor.abs().max() > 1:
            audio_tensor = audio_tensor / audio_tensor.abs().max()
        # Each latent frame decodes to 2000 samples at 16 kHz, so the prompt spans
        # prompt_len * 2000 samples; keep the last second of it for context.
        output_audio = audio_tensor[:, max(prompt_len * 2000 - 16000, 0):]
        completions.append((16000, output_audio.numpy().T))
    progress(1.0, desc="Generation complete!")
    return completions
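
# Each completion is a (sample_rate, ndarray) tuple, the format gr.Audio
# components accept when type="numpy".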

def create_interface():
    # Initialize the model at startup so the first request is not delayed
    init_model()
    with gr.Blocks(title="Audio Completion Generator") as app:
        gr.Markdown("""
# Audio Completion Generator
Upload an audio file (or use the default) and generate AI completions based on the prompt.
""")
        with gr.Row():
            with gr.Column():
                # Preload the default audio only if the file exists
                default_value = default_audio_path if os.path.exists(default_audio_path) else None
                audio_input = gr.Audio(
                    label="Input Audio",
                    type="filepath",
                    sources=["microphone", "upload"],
                    value=default_value
                )
                with gr.Row():
                    prompt_len = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=0.5,
                        label="Prompt Length (seconds)"
                    )
                    default_num_completions = 5
                    num_completions = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=default_num_completions,
                        step=1,
                        label="Number of Completions"
                    )
                    gen_length = gr.Slider(
                        minimum=5,
                        maximum=60,
                        value=20,
                        step=5,
                        label="Generation Length (seconds)"
                    )
                with gr.Row():
                    token_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.8,
                        step=0.1,
                        label="Token Temperature"
                    )
                    cat_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.5,
                        step=0.1,
                        label="Categorical Temperature"
                    )
                    gauss_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.1,
                        step=0.1,
                        label="Gaussian Temperature"
                    )
                generate_btn = gr.Button("Generate Completions")
                status_text = gr.Markdown("Ready")
            with gr.Column():
                # Pre-create the maximum number of output slots (10); their
                # visibility is toggled to match the slider below
                output_audios = []
                for i in range(10):
                    output_audios.append(gr.Audio(
                        label=f"Generated Completion {i+1}",
                        type="numpy",
                        visible=False
                    ))

        def update_visibility(num):
            # Show only the first `num` output slots
            return [gr.update(visible=(i < num)) for i in range(10)]

        def generate_with_status(*args):
            completions = generate_completion(*args)
            # Fill all ten slots; unused slots get None so stale audio is cleared
            outputs = []
            for i in range(10):
                if i < len(completions):
                    outputs.append(completions[i])
                else:
                    outputs.append(None)
            # Report status as a regular output: assigning to status_text.value
            # inside a callback does not update the rendered component
            return outputs + ["Generation complete!"]

        # Set initial visibility on load
        app.load(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios
        )
        # Update visibility when the slider changes
        num_completions.change(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios
        )
        generate_btn.click(
            fn=generate_with_status,
            inputs=[
                audio_input,
                prompt_len,
                num_completions,
                gen_length,
                token_temp,
                cat_temp,
                gauss_temp
            ],
            outputs=output_audios + [status_text]
        )
    return app


if __name__ == "__main__":
    app = create_interface()
    # share=True also exposes a temporary public Gradio link
    app.launch(share=True)