| import warnings | |
| import spaces | |
| warnings.filterwarnings("ignore") | |
| import logging | |
| from argparse import ArgumentParser | |
| from pathlib import Path | |
| import torch | |
| import torchaudio | |
| import gradio as gr | |
| from transformers import AutoModel | |
| import laion_clap | |
| from meanaudio.eval_utils import ( | |
| ModelConfig, | |
| all_model_cfg, | |
| generate_mf, | |
| generate_fm, | |
| setup_eval_logging, | |
| ) | |
| from meanaudio.model.flow_matching import FlowMatching | |
| from meanaudio.model.mean_flow import MeanFlow | |
| from meanaudio.model.networks import MeanAudio, get_mean_audio | |
| from meanaudio.model.utils.features_utils import FeaturesUtils | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| import gc | |
| import json | |
| from datetime import datetime | |
| from huggingface_hub import snapshot_download | |
| import numpy as np | |
# Root logger; configured by setup_eval_logging() below.
log = logging.getLogger()

# Run on the GPU whenever torch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"

setup_eval_logging()

# Where generated clips are written.
OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Clips generated per request.
NUM_SAMPLE = 1

# Create the RLHF feedback data directory.
FEEDBACK_DIR = Path("./rlhf")
FEEDBACK_DIR.mkdir(exist_ok=True)
FEEDBACK_FILE = FEEDBACK_DIR / "user_preferences.jsonl"

# Process-wide caches so networks / feature extractors are loaded only once.
MODEL_CACHE = dict()
FEATURE_UTILS_CACHE = dict()
def fade_out(x, sr, fade_ms=50):
    """Apply a linear fade-out to the end of an audio signal, in place.

    The ramp is applied along the LAST axis (time), so both 1-D
    ``(samples,)`` and multi-channel ``(channels, samples)`` signals are
    faded correctly.

    Args:
        x: NumPy array or torch tensor whose last axis is time. Modified in
            place.
        sr: Sampling rate in Hz.
        fade_ms: Fade length in milliseconds.

    Returns:
        ``x`` itself. Its final ``int(sr * fade_ms / 1000)`` samples are
        scaled by a ramp going linearly from 1.0 to 0.0. ``x`` is returned
        untouched when that length is not positive or is not shorter than
        the signal.
    """
    # Time is the last axis: for (channels, samples) audio, len(x) would be
    # the channel count and the fade would be skipped entirely.
    n = x.shape[-1]
    k = int(sr * fade_ms / 1000)
    if k <= 0 or k >= n:
        return x
    w = np.linspace(1.0, 0.0, k)
    if isinstance(x, torch.Tensor):
        # Keep the multiply inside torch, on the same device as the audio.
        w = torch.from_numpy(w).to(x.device)
    x[..., -k:] = x[..., -k:] * w
    return x
def ensure_models_downloaded():
    """Fetch the MeanAudio weight snapshot if any configured checkpoint is missing.

    Configured variants are checked in order. At the first one whose
    ``model_path`` does not exist, a single ``snapshot_download`` of the
    ``AndreasXi/MeanAudio`` repo into ``./weights`` is performed. Nothing
    happens when every checkpoint is already present.
    """
    missing_variant = next(
        (
            name
            for name, cfg in all_model_cfg.items()
            if not cfg.model_path.exists()
        ),
        None,
    )
    if missing_variant is None:
        return
    log.info(f'Model {missing_variant} not found, downloading...')
    snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights")
def load_model_cache():
    """Load every variant in ``all_model_cfg`` into the module-level caches.

    Each variant not yet in ``MODEL_CACHE`` gets a network built with
    ``get_mean_audio``, cast to bfloat16 on ``device``, switched to eval
    mode, and loaded with its checkpoint. Afterwards the single shared
    ``FEATURE_UTILS_CACHE['default']`` is built if it is missing.

    Entries that are already cached are left untouched, so repeated calls
    do no reloading.
    """
    for variant, model_cfg in all_model_cfg.items():
        # Skip (rather than return on) cached variants so the remaining
        # ones still get loaded.
        if variant in MODEL_CACHE:
            continue
        log.info(f"Loading model {variant} for the first time...")
        net = get_mean_audio(model_cfg.model_name, use_rope=True, text_c_dim=512)
        net = net.to(device, torch.bfloat16).eval()
        net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
        MODEL_CACHE[variant] = net

    if 'default' in FEATURE_UTILS_CACHE or not all_model_cfg:
        return
    # One FeaturesUtils instance (text encoders, VAE, vocoder) is used for
    # every variant, so build it once instead of once per variant. Its paths
    # and mode come from the last configured variant.
    # NOTE(review): assumes all variants share the same VAE, vocoder and
    # mode — confirm.
    feature_cfg = list(all_model_cfg.values())[-1]
    FEATURE_UTILS_CACHE['default'] = FeaturesUtils(
        tod_vae_ckpt=feature_cfg.vae_path,
        enable_conditions=True,
        encoder_name="t5_clap",
        mode=feature_cfg.mode,
        bigvgan_vocoder_ckpt=feature_cfg.bigvgan_16k_path,
        need_vae_encoder=False,
    ).to(device, torch.bfloat16).eval()
def save_preference_feedback(prompt, audio1_path, audio2_path, preference, additional_comment=""):
    """Append one pairwise-preference record to ``FEEDBACK_FILE`` as a JSON line.

    Args:
        prompt: Text prompt both clips were generated from.
        audio1_path: Path of the first clip.
        audio2_path: Path of the second clip.
        preference: One of "audio1", "audio2", "equal", "both_bad".
        additional_comment: Optional free-text remark.

    Returns:
        A confirmation message naming the recorded preference.
    """
    record = dict(
        timestamp=datetime.now().isoformat(),
        prompt=prompt,
        audio1_path=audio1_path,
        audio2_path=audio2_path,
        preference=preference,  # "audio1", "audio2", "equal", "both_bad"
        additional_comment=additional_comment,
    )
    # UTF-8 + ensure_ascii=False keeps non-ASCII prompts human-readable.
    with open(FEEDBACK_FILE, "a", encoding="utf-8") as sink:
        sink.write(json.dumps(record, ensure_ascii=False) + "\n")
    log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
    return f"✅ Thanks for your feedback, preference recorded: {preference}"
def generate_audio_gradio(
    prompt,
    duration,
    cfg_strength,
    num_steps,
    variant,
    seed
):
    """Generate audio for ``prompt`` with the chosen variant and save it as FLAC.

    Args:
        prompt: Text description of the sound to generate.
        duration: Target clip length; must be positive.
        cfg_strength: Classifier-free guidance scale. Forced to 0 for
            MeanFlow variants.
        num_steps: Number of sampler steps; coerced to ``int`` and must be
            positive.
        variant: Key of ``all_model_cfg`` selecting the model.
        seed: Seed for the sampler's ``torch.Generator``; coerced to ``int``.

    Returns:
        ``(path, prompt)``: the path of the first saved ``.flac`` file under
        ``OUTPUT_DIR``, and the prompt that was used.

    Raises:
        ValueError: non-positive duration or step count, an unknown variant,
            or a variant with no sampler mapping.
    """
    # Gradio sliders may hand over floats, but the samplers take a step count
    # and torch.Generator.manual_seed only accepts integers.
    num_steps = int(num_steps)
    seed = int(seed)
    if duration <= 0 or num_steps <= 0:
        raise ValueError("Duration and number of steps must be positive.")
    if variant not in all_model_cfg:
        raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")

    # Decide the sampler family up front so a variant without a mapping
    # raises a clear ValueError instead of an UnboundLocalError later.
    if variant in ('meanaudio_s_ac', 'meanaudio_s_full'):
        use_meanflow = True
    elif variant == 'fluxaudio_s_full':
        use_meanflow = False
    else:
        raise ValueError(f"No sampler configured for model variant: {variant}")

    # Load lazily: nothing else in this module is guaranteed to have filled
    # the caches before the first request arrives.
    if variant not in MODEL_CACHE or 'default' not in FEATURE_UTILS_CACHE:
        ensure_models_downloaded()
        load_model_cache()
    net, feature_utils = MODEL_CACHE[variant], FEATURE_UTILS_CACHE['default']

    model_cfg = all_model_cfg[variant]
    seq_cfg = model_cfg.seq_cfg
    # NOTE: this mutates the shared per-variant sequence config and network
    # to match the requested duration.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len)

    if use_meanflow:
        sampler = MeanFlow(steps=num_steps)
        log.info("Using MeanFlow for generation.")
        generation_func = generate_mf
        sampler_arg_name = "mf"
        cfg_strength = 0
    else:
        sampler = FlowMatching(
            min_sigma=0, inference_mode="euler", num_steps=num_steps
        )
        log.info("Using FlowMatching for generation.")
        generation_func = generate_fm
        sampler_arg_name = "fm"

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    audios = generation_func(
        [prompt] * NUM_SAMPLE,
        negative_text=None,
        feature_utils=feature_utils,
        net=net,
        rng=rng,
        cfg_strength=cfg_strength,
        **{sampler_arg_name: sampler},
    )

    # Filesystem-safe filename stem: keep alphanumerics, spaces and
    # underscores, turn spaces into underscores, and cap at 50 characters.
    safe_prompt = (
        "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
        .rstrip()
        .replace(" ", "_")[:50]
    )
    save_paths = []
    for i, audio in enumerate(audios):
        audio = audio.float().cpu()
        audio = fade_out(audio, seq_cfg.sampling_rate, fade_ms=100)
        current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
        save_path = OUTPUT_DIR / filename
        torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
        log.info(f"Audio saved to {save_path}")
        save_paths.append(str(save_path))

    if device == "cuda":
        torch.cuda.empty_cache()
    return save_paths[0], prompt
# Gradio input and output components.
# These are created at module level and wired into gr.Interface below. The
# module-level names `duration`, `cfg_strength`, `seed` and `variant` hold
# component objects; they are distinct from the same-named parameters of
# generate_audio_gradio.
# NOTE(review): `output_audio` is not passed to the Interface (it builds its
# own gr.Audio output), so it appears unused — confirm before removing.
input_text = gr.Textbox(lines=2, label="Prompt")
output_audio = gr.Audio(label="Generated Audio", type="filepath")
# Sampler step count (maps to `num_steps`); the default of 1 means
# single-step generation.
denoising_steps = gr.Slider(minimum=1, maximum=25, value=1, step=1, label="Sampling Steps", interactive=True)
# Classifier-free guidance scale. generate_audio_gradio overrides it to 0
# for MeanFlow variants.
cfg_strength = gr.Slider(minimum=1, maximum=10, value=4.5, step=0.5, label="Guidance Scale", interactive=True)
duration = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Duration", interactive=True)
seed = gr.Slider(minimum=1, maximum=100, value=42, step=1, label="Seed", interactive=True)
# Choices come straight from the configured model variants.
variant = gr.Dropdown(label="Model Variant", choices=list(all_model_cfg.keys()), value='meanaudio_s_full', interactive=True)
| # description_text = """ | |
| # **MeanAudio** is a novel text-to-audio generator that uses **MeanFlow** to synthesize realistic and faithful audio in few sampling steps. It achieves state-of-the-art performance in single-step audio generation and delivers strong performance in multi-step audio generation. | |
| # <p align="center"> | |
| # <a href="https://huggingface.co/AndreasXi/MeanAudio"> | |
| # <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model-HuggingFace-violet" alt="HuggingFace Model"> | |
| # </a> | |
| # <a href="https://huggingface.co/spaces/chenxie95/MeanAudio"> | |
| # <img src="https://img.shields.io/badge/%F0%9F%9A%80%20Space-HuggingFace-8A2BE2" alt="HuggingFace Space"> | |
| # </a> | |
| # <a href="https://meanaudio.github.io/"> | |
| # <img src="https://img.shields.io/badge/%F0%9F%93%84%20Project-Page-brightred" alt="Project Page"> | |
| # </a> | |
| # <a href="https://github.com/xiquan-li/MeanAudio"> | |
| # <img src="https://img.shields.io/badge/%F0%9F%92%BB%20Code-GitHub-black" alt="GitHub"> | |
| # </a> | |
| # </p> | |
| # """ | |
# Markdown rendered under the interface title (passed as `description=` to
# gr.Interface below).
description_text = """
### **MeanAudio** is a novel text-to-audio generator that uses **MeanFlow** to synthesize realistic and faithful audio in few sampling steps. It achieves state-of-the-art performance in single-step audio generation and delivers strong performance in multi-step audio generation.
### [📖 **Arxiv**](https://arxiv.org/abs/2508.06098) | [💻 **GitHub**](https://github.com/xiquan-li/MeanAudio) | [🤗 **Model**](https://huggingface.co/AndreasXi/MeanAudio) | [🚀 **Space**](https://huggingface.co/spaces/chenxie95/MeanAudio) | [🌐 **Project Page**](https://meanaudio.github.io/)
"""
# The demo UI. Inputs are passed to generate_audio_gradio positionally:
# (prompt, duration, cfg_strength, num_steps, variant, seed).
# Every example uses the same settings: duration 10, guidance 3, one
# sampling step, the "meanaudio_s_full" variant and seed 42.
gr_interface = gr.Interface(
    fn=generate_audio_gradio,
    inputs=[input_text, duration, cfg_strength, denoising_steps, variant, seed],
    outputs=[
        gr.Audio(label="🎵 Audio Sample", type="filepath"),
        gr.Textbox(label="Prompt Used", interactive=False),
    ],
    title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
    description=description_text,
    flagging_mode="never",
    examples=[
        [example_prompt, 10, 3, 1, "meanaudio_s_full", 42]
        for example_prompt in (
            "Guitar and piano playing a warm music, with a soft and gentle melody, perfect for a romantic evening.",
            "Melodic human whistling harmonizing with natural birdsong",
            "A parade marches through a town square, with drumbeats pounding, children clapping, and a horse neighing amidst the commotion",
            "Quiet speech and then and airplane flying away",
            "A basketball bounces rhythmically on a court, shoes squeak against the floor, and a referee’s whistle cuts through the air",
            "Chopping meat on a wooden table.",
            "A vehicle engine revving then accelerating at a high rate as a metal surface is whipped followed by tires skidding.",
            "Battlefield scene, continuous roar of artillery and gunfire, high fidelity, the sharp crack of bullets, the thundering explosions of bombs, and the screams of wounded soldiers.",
            "Pop music that upbeat, catchy, and easy to listen, high fidelity, with simple melodies, electronic instruments and polished production.",
            "A fast-paced instrumental piece with a classical vibe featuring stringed instruments, evoking an energetic and uplifting mood.",
        )
    ],
    # Examples are generated on first click rather than at startup.
    cache_examples="lazy",
)
if __name__ == "__main__":
    # NOTE(review): model download/loading is disabled here. generate_audio_gradio
    # reads MODEL_CACHE / FEATURE_UTILS_CACHE, so confirm those caches are
    # populated before the first request (the two calls below would do it).
    # ensure_models_downloaded()
    # load_model_cache()
    # NOTE(review): in Gradio 4/5 the first positional parameter of queue() is
    # `status_update_rate`; in Gradio 3 it was `concurrency_count`. Confirm
    # which meaning `15` is intended to have.
    gr_interface.queue(15).launch()
| # theme = gr.themes.Soft( | |
| # primary_hue="blue", | |
| # secondary_hue="slate", | |
| # neutral_hue="slate", | |
| # text_size="sm", | |
| # spacing_size="sm", | |
| # ).set( | |
| # background_fill_primary="*neutral_50", | |
| # background_fill_secondary="*background_fill_primary", | |
| # block_background_fill="*background_fill_primary", | |
| # block_border_width="0px", | |
| # panel_background_fill="*neutral_50", | |
| # panel_border_width="0px", | |
| # input_background_fill="*neutral_100", | |
| # input_border_color="*neutral_200", | |
| # button_primary_background_fill="*primary_300", | |
| # button_primary_background_fill_hover="*primary_400", | |
| # button_secondary_background_fill="*neutral_200", | |
| # button_secondary_background_fill_hover="*neutral_300", | |
| # ) | |
| # custom_css = """ | |
| # #main-headertitle { | |
| # text-align: center; | |
| # margin-top: 15px; | |
| # margin-bottom: 10px; | |
| # color: var(--neutral-600); | |
| # font-weight: 600; | |
| # } | |
| # #main-header { | |
| # text-align: center; | |
| # margin-top: 5px; | |
| # margin-bottom: 10px; | |
| # color: var(--neutral-600); | |
| # font-weight: 600; | |
| # } | |
| # #model-settings-header, #generation-settings-header { | |
| # color: var(--neutral-600); | |
| # margin-top: 8px; | |
| # margin-bottom: 8px; | |
| # font-weight: 500; | |
| # font-size: 1.1em; | |
| # } | |
| # .setting-section { | |
| # padding: 10px 12px; | |
| # border-radius: 6px; | |
| # background-color: var(--neutral-50); | |
| # margin-bottom: 10px; | |
| # border: 1px solid var(--neutral-100); | |
| # } | |
| # hr { | |
| # border: none; | |
| # height: 1px; | |
| # background-color: var(--neutral-200); | |
| # margin: 8px 0; | |
| # } | |
| # #generate-btn { | |
| # width: 100%; | |
| # max-width: 250px; | |
| # margin: 10px auto; | |
| # display: block; | |
| # padding: 10px 15px; | |
| # font-size: 16px; | |
| # border-radius: 5px; | |
| # } | |
| # #status-box { | |
| # min-height: 50px; | |
| # display: flex; | |
| # align-items: center; | |
| # justify-content: center; | |
| # padding: 8px; | |
| # border-radius: 5px; | |
| # border: 1px solid var(--neutral-200); | |
| # color: var(--neutral-700); | |
| # } | |
| # #project-badges { | |
| # text-align: center; | |
| # margin-top: 30px; | |
| # margin-bottom: 20px; | |
| # } | |
| # #project-badges #badge-container { | |
| # display: flex; | |
| # gap: 10px; | |
| # align-items: center; | |
| # justify-content: center; | |
| # flex-wrap: wrap; | |
| # } | |
| # #project-badges img { | |
| # border-radius: 5px; | |
| # box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); | |
| # height: 20px; | |
| # transition: transform 0.1s ease, box-shadow 0.1s ease; | |
| # } | |
| # #project-badges a:hover img { | |
| # transform: translateY(-2px); | |
| # box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); | |
| # } | |
| # #audio-output { | |
| # height: 200px; | |
| # border-radius: 5px; | |
| # border: 1px solid var(--neutral-200); | |
| # } | |
| # .gradio-dropdown label, .gradio-checkbox label, .gradio-number label, .gradio-textbox label { | |
| # font-weight: 500; | |
| # color: var(--neutral-700); | |
| # font-size: 0.9em; | |
| # } | |
| # .gradio-row { | |
| # gap: 8px; | |
| # } | |
| # .gradio-block { | |
| # margin-bottom: 8px; | |
| # } | |
| # .setting-section .gradio-block { | |
| # margin-bottom: 6px; | |
| # } | |
| # ::-webkit-scrollbar { | |
| # width: 8px; | |
| # height: 8px; | |
| # } | |
| # ::-webkit-scrollbar-track { | |
| # background: var(--neutral-100); | |
| # border-radius: 4px; | |
| # } | |
| # ::-webkit-scrollbar-thumb { | |
| # background: var(--neutral-300); | |
| # border-radius: 4px; | |
| # } | |
| # ::-webkit-scrollbar-thumb:hover { | |
| # background: var(--neutral-400); | |
| # } | |
| # * { | |
| # scrollbar-width: thin; | |
| # scrollbar-color: var(--neutral-300) var(--neutral-100); | |
| # } | |
| # """ | |
| # with gr.Blocks(title="MeanAudio Generator", theme=theme, css=custom_css) as demo: | |
| # gr.Markdown("# MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows", elem_id="main-header") | |
| # badge_html = ''' | |
# <div id="project-badges"> <!-- use an ID so the CSS rules apply -->
# <div id="badge-container"> <!-- added container div, also targeted by ID -->
| # <a href="https://huggingface.co/junxiliu/MeanAudio"> | |
| # <img src="https://img.shields.io/badge/Model-HuggingFace-violet?logo=huggingface" alt="Hugging Face Model"> | |
| # </a> | |
| # <a href="https://huggingface.co/spaces/chenxie95/MeanAudio"> | |
| # <img src="https://img.shields.io/badge/Space-HuggingFace-8A2BE2?logo=huggingface" alt="Hugging Face Space"> | |
| # </a> | |
| # <a href="https://meanaudio.github.io/"> | |
| # <img src="https://img.shields.io/badge/Project-Page-brightred?style=flat" alt="Project Page"> | |
| # </a> | |
| # <a href="https://github.com/xiquan-li/MeanAudio"> | |
| # <img src="https://img.shields.io/badge/Code-GitHub-black?logo=github" alt="GitHub"> | |
| # </a> | |
| # </div> | |
| # </div> | |
| # ''' | |
| # gr.HTML(badge_html) | |
| # with gr.Column(elem_classes="setting-section"): | |
| # with gr.Row(): | |
| # available_variants = ( | |
| # list(all_model_cfg.keys()) if all_model_cfg else [] | |
| # ) | |
| # default_variant = ( | |
| # 'meanaudio_mf' | |
| # ) | |
| # variant = gr.Dropdown( | |
| # label="Model Variant", | |
| # choices=available_variants, | |
| # value=default_variant, | |
| # interactive=True, | |
| # scale=3, | |
| # ) | |
| # with gr.Column(elem_classes="setting-section"): | |
| # with gr.Row(): | |
| # prompt = gr.Textbox( | |
| # label="Prompt", | |
| # placeholder="Describe the sound you want to generate...", | |
| # scale=1, | |
| # ) | |
| # negative_prompt = gr.Textbox( | |
| # label="Negative Prompt", | |
| # placeholder="Describe sounds you want to avoid...", | |
| # value="", | |
| # scale=1, | |
| # ) | |
| # with gr.Row(): | |
| # duration = gr.Number( | |
| # label="Duration (sec)", value=10.0, minimum=0.1, scale=1 | |
| # ) | |
| # cfg_strength = gr.Number( | |
| # label="CFG (Meanflow forced to 3)", value=3, minimum=0.0, scale=1 | |
| # ) | |
| # with gr.Row(): | |
| # seed = gr.Number( | |
| # label="Seed (-1 for random)", value=42, precision=0, scale=1 | |
| # ) | |
| # num_steps = gr.Number( | |
| # label="Number of Steps", | |
| # value=1, | |
| # precision=0, | |
| # minimum=1, | |
| # scale=1, | |
| # ) | |
| # generate_button = gr.Button("Generate", variant="primary", elem_id="generate-btn") | |
| # generate_output_text = gr.Textbox( | |
| # label="Result Status", interactive=False, elem_id="status-box" | |
| # ) | |
| # audio_output = gr.Audio( | |
| # label="Generated Audio", type="filepath", elem_id="audio-output" | |
| # ) | |
| # generate_button.click( | |
| # fn=generate_audio_gradio, | |
| # inputs=[ | |
| # prompt, | |
| # negative_prompt, | |
| # duration, | |
| # cfg_strength, | |
| # num_steps, | |
| # seed, | |
| # variant, | |
| # ], | |
| # outputs=[generate_output_text, audio_output], | |
| # ) | |
| # audio_examples = [ | |
| # ["Typing on a keyboard", "", 10.0, 3, 1, 42, "meanaudio_mf"], | |
| # ["A man speaks followed by a popping noise and laughter", "", 10.0, 3, 1, 42, "meanaudio_mf"], | |
| # ["Some humming followed by a toilet flushing", "", 10.0, 3, 2, 42, "meanaudio_mf"], | |
| # ["Rain falling on a hard surface as thunder roars in the distance", "", 10.0, 3, 5, 42, "meanaudio_mf"], | |
| # ["Food sizzling and oil popping", "", 10.0, 3, 25, 42, "meanaudio_mf"], | |
| # ["Pots and dishes clanking as a man talks followed by liquid pouring into a container", "", 8.0, 3, 2, 42, "meanaudio_mf"], | |
| # ["A few seconds of silence then a rasping sound against wood", "", 12.0, 3, 2, 42, "meanaudio_mf"], | |
| # ["A man speaks as he gives a speech and then the crowd cheers", "", 10.0, 3, 25, 42, "fluxaudio_fm"], | |
| # ["A goat bleating repeatedly", "", 10.0, 3, 50, 123, "fluxaudio_fm"], | |
| # ["A speech and gunfire followed by a gun being loaded", "", 10.0, 3, 1, 42, "meanaudio_mf"], | |
| # ["Tires squealing followed by an engine revving", "", 12.0, 4, 25, 456, "fluxaudio_fm"], | |
| # ["Hammer slowly hitting the wooden table", "", 10.0, 3.5, 25, 42, "fluxaudio_fm"], | |
| # ["Dog barking excitedly and man shouting as race car engine roars past", "", 10.0, 3, 1, 42, "meanaudio_mf"], | |
| # ["A dog barking and a cat mewing and a racing car passes by", "", 12.0, 3, 5, -1, "meanaudio_mf"], | |
| # ["Whistling with birds chirping", "", 10.0, 4, 50, 42, "fluxaudio_fm"], | |
| # ] | |
| # gr.Examples( | |
| # examples=audio_examples, | |
| # inputs=[prompt, negative_prompt, duration, cfg_strength, num_steps, seed, variant], | |
| # #outputs=[generate_output_text, audio_output], | |
| # #fn=generate_audio_gradio, | |
| # examples_per_page=5, | |
| # label="Example Prompts", | |
| # ) | |
| # if __name__ == "__main__": | |
| # demo.launch() | |