add rlhf
Files changed:
- app.py (+41 −16)
- feedback_collector.py (+127 −0)
app.py
CHANGED
@@ -23,6 +23,7 @@ from meanaudio.model.utils.features_utils import FeaturesUtils
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 import gc
+import json
 from datetime import datetime
 from huggingface_hub import snapshot_download
 import numpy as np
@@ -38,6 +39,11 @@ OUTPUT_DIR = Path("./output/gradio")
 OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
 NUM_SAMPLE = 2
 
+# Create the RLHF feedback data directory
+FEEDBACK_DIR = Path("./rlhf")
+FEEDBACK_DIR.mkdir(exist_ok=True)
+FEEDBACK_FILE = FEEDBACK_DIR / "user_preferences.jsonl"
+
 # Global model cache to avoid reloading
 MODEL_CACHE = {}
 FEATURE_UTILS_CACHE = {}
@@ -80,6 +86,22 @@ def load_model_cache():
     ).to(device, torch.bfloat16).eval()
     FEATURE_UTILS_CACHE['default'] = feature_utils
 
+def save_preference_feedback(prompt, audio1_path, audio2_path, preference, additional_comment=""):
+    feedback_data = {
+        "timestamp": datetime.now().isoformat(),
+        "prompt": prompt,
+        "audio1_path": audio1_path,
+        "audio2_path": audio2_path,
+        "preference": preference,  # "audio1", "audio2", "equal", "both_bad"
+        "additional_comment": additional_comment
+    }
+
+    with open(FEEDBACK_FILE, "a", encoding="utf-8") as f:
+        f.write(json.dumps(feedback_data, ensure_ascii=False) + "\n")
+
+    log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
+    return f"✅ Thanks for your feedback, preference recorded: {preference}"
+
 
 @spaces.GPU(duration=60)
 @torch.inference_mode()
@@ -97,7 +119,7 @@ def generate_audio_gradio(
         raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")
 
     net, feature_utils = MODEL_CACHE[variant], FEATURE_UTILS_CACHE['default']
-
+
     model = all_model_cfg[variant]
     seq_cfg = model.seq_cfg
     seq_cfg.duration = duration
@@ -142,21 +164,21 @@ def generate_audio_gradio(
 
         audio = fade_out(audio, seq_cfg.sampling_rate)
 
-
-
-
-
-
-
-
-
-
-
-
+        safe_prompt = (
+            "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
+            .rstrip()
+            .replace(" ", "_")[:50]
+        )
+        current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+        filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
+        save_path = OUTPUT_DIR / filename
+        torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
+        log.info(f"Audio saved to {save_path}")
+        save_paths.append(str(save_path))
     if device == "cuda":
         torch.cuda.empty_cache()
 
-    return save_paths
+    return save_paths[0], save_paths[1], prompt
 
 
 # Gradio input and output components
@@ -171,9 +193,13 @@ variant = gr.Dropdown(label="Model Variant", choices=list(all_model_cfg.keys()),
 gr_interface = gr.Interface(
     fn=generate_audio_gradio,
     inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
-    outputs=[
+    outputs=[
+        gr.Audio(label="🎵 Audio Sample 1"),
+        gr.Audio(label="🎵 Audio Sample 2"),
+        gr.Textbox(label="Prompt Used", interactive=False)
+    ],
     title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
-    description="",
+    description="🎯 **RLHF data collection**: two audio samples are now generated! The collected feedback data is used to improve the model. Analysis tool: `python analyze_feedback.py`",
     flagging_mode="never",
     examples=[
         ["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
@@ -193,7 +219,6 @@ gr_interface = gr.Interface(
 )
 
 if __name__ == "__main__":
-
     ensure_models_downloaded()
     load_model_cache()
     gr_interface.queue(15).launch()
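
Note: each rating is appended to rlhf/user_preferences.jsonl as one JSON object per line. For reference, a record written by save_preference_feedback (or by save_feedback in feedback_collector.py below) would look roughly like the following; the field names come from the code above, while the values here are purely illustrative:

{"timestamp": "2025-01-01T12:34:56.789012", "prompt": "a dog barking in the rain",
 "audio1_path": "output/gradio/a_dog_barking_in_the_rain_20250101_123456_789012_0.flac",
 "audio2_path": "output/gradio/a_dog_barking_in_the_rain_20250101_123456_789012_1.flac",
 "preference": "audio1", "additional_comment": ""}
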
feedback_collector.py
ADDED
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+"""
+A simple feedback collection tool.
+Run this script after MeanAudio has generated audio to collect user preferences.
+"""
+
+import json
+import os
+import sys
+from datetime import datetime
+from pathlib import Path
+import gradio as gr
+
+# Set up the feedback directory
+FEEDBACK_DIR = Path("./rlhf_feedback")
+FEEDBACK_DIR.mkdir(exist_ok=True)
+FEEDBACK_FILE = FEEDBACK_DIR / "user_preferences.jsonl"
+
+def save_feedback(audio1_path, audio2_path, prompt, preference, comment=""):
+    """Save feedback data."""
+    feedback_data = {
+        "timestamp": datetime.now().isoformat(),
+        "prompt": prompt,
+        "audio1_path": audio1_path,
+        "audio2_path": audio2_path,
+        "preference": preference,
+        "additional_comment": comment
+    }
+
+    with open(FEEDBACK_FILE, "a", encoding="utf-8") as f:
+        f.write(json.dumps(feedback_data, ensure_ascii=False) + "\n")
+
+    return f"✅ Feedback saved! Preference: {preference}"
+
+def create_feedback_interface():
+    """Create the feedback collection interface."""
+
+    with gr.Blocks(title="MeanAudio Feedback Collector") as demo:
+        gr.Markdown("# MeanAudio Feedback Collector")
+        gr.Markdown("*Enter the paths of the generated audio files and the prompt, then select your preference*")
+
+        with gr.Row():
+            with gr.Column():
+                prompt_input = gr.Textbox(
+                    label="Prompt",
+                    placeholder="Enter the prompt used to generate the audio..."
+                )
+
+                audio1_path = gr.Textbox(
+                    label="Audio File 1 Path",
+                    placeholder="./output/gradio/prompt_timestamp_0.flac"
+                )
+
+                audio2_path = gr.Textbox(
+                    label="Audio File 2 Path",
+                    placeholder="./output/gradio/prompt_timestamp_1.flac"
+                )
+
+            with gr.Column():
+                # Audio playback
+                audio1_player = gr.Audio(label="Audio 1")
+                audio2_player = gr.Audio(label="Audio 2")
+
+                load_btn = gr.Button("🔄 Load Audio Files")
+
+        # Feedback section
+        gr.Markdown("---")
+        gr.Markdown("### Select your preference")
+
+        preference = gr.Radio(
+            choices=[
+                ("Audio 1 is better", "audio1"),
+                ("Audio 2 is better", "audio2"),
+                ("Both are equally good", "equal"),
+                ("Both are bad", "both_bad")
+            ],
+            label="Which audio is better?"
+        )
+
+        comment = gr.Textbox(
+            label="Additional Comment (optional)",
+            placeholder="Specific feedback on the audio quality...",
+            lines=3
+        )
+
+        submit_btn = gr.Button("📝 Submit Feedback", variant="primary")
+
+        result = gr.Textbox(label="Result", interactive=False)
+
+        # Event handlers
+        def load_audio_files(path1, path2):
+            """Load the audio files for playback."""
+            audio1 = path1 if os.path.exists(path1) else None
+            audio2 = path2 if os.path.exists(path2) else None
+            return audio1, audio2
+
+        load_btn.click(
+            fn=load_audio_files,
+            inputs=[audio1_path, audio2_path],
+            outputs=[audio1_player, audio2_player]
+        )
+
+        submit_btn.click(
+            fn=save_feedback,
+            inputs=[audio1_path, audio2_path, prompt_input, preference, comment],
+            outputs=[result]
+        )
+
+        # Usage instructions
+        gr.Markdown("---")
+        gr.Markdown("""
+        ### Usage
+        1. First run MeanAudio to generate two audio files
+        2. Copy the generated audio file paths into the inputs above
+        3. Click "Load Audio Files" to play the audio
+        4. Select your preference and submit your feedback
+        5. Feedback data is saved to `./rlhf_feedback/user_preferences.jsonl`
+        6. Use `python analyze_feedback.py` to analyze the collected feedback
+        """)
+
+    return demo
+
+if __name__ == "__main__":
+    demo = create_feedback_interface()
+    print("Launching the feedback collection interface...")
+    print(f"Feedback data will be saved to: {FEEDBACK_FILE}")
+    demo.launch(server_name="127.0.0.1", server_port=7861)
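
Note: both the interface description in app.py and the usage notes above reference an analyze_feedback.py script that is not included in this commit. Below is a minimal sketch of what such a script could look like, assuming it only tallies the "preference" field of the JSONL log; the file path, field handling, and output format are assumptions, not the actual script.

#!/usr/bin/env python3
# Hypothetical sketch of analyze_feedback.py (not part of this commit).
# Assumes the JSONL records written by save_feedback / save_preference_feedback.
import json
from collections import Counter
from pathlib import Path

FEEDBACK_FILE = Path("./rlhf_feedback/user_preferences.jsonl")  # app.py writes to ./rlhf/ instead

def main():
    counts = Counter()
    total = 0
    with open(FEEDBACK_FILE, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            counts[record.get("preference", "unknown")] += 1
            total += 1

    if total == 0:
        print("No feedback collected yet.")
        return

    print(f"{total} feedback records in {FEEDBACK_FILE}")
    for preference, count in counts.most_common():
        print(f"  {preference}: {count} ({count / total:.1%})")

if __name__ == "__main__":
    main()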