add rlhf
Browse files
app.py
CHANGED
|
@@ -102,39 +102,6 @@ def save_preference_feedback(prompt, audio1_path, audio2_path, preference, addit
|
|
| 102 |
log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
|
| 103 |
return f"✅ Thanks for your feedback, preference recorded: {preference}"
|
| 104 |
|
| 105 |
-
def save_preference_feedback_from_flag(input_text, duration, cfg_strength, num_steps, variant,
                                       audio1_path, audio2_path, prompt_used, preference, comment):
    """Save one preference record coming from the Gradio flagging callback.

    Appends a single JSON line to ``FEEDBACK_FILE`` holding the timestamp,
    the prompt, both audio paths, the chosen preference, the optional
    comment and the generation parameters.

    Args:
        input_text: Text typed by the user; used as the prompt when
            ``prompt_used`` is empty.
        duration, cfg_strength, num_steps, variant: Generation parameters,
            stored verbatim under ``generation_params``.
        audio1_path, audio2_path: Values stored for the two audio samples.
        prompt_used: Prompt reported by the generator; may be empty/None.
        preference: The selected preference; nothing is written when falsy.
        comment: Optional free-form comment; stored as "" when falsy.

    Returns:
        None. Failures are logged and printed instead of being raised, so
        the callback never propagates an exception to the caller.
    """
    try:
        if not preference:
            print("⚠️ 用户没有选择偏好")
            return

        # Resolve the prompt once so the stored record and the log line agree.
        prompt = prompt_used or input_text
        # Slicing prompt_used directly raised TypeError when it was None,
        # which reported a save failure *after* the record had been written.
        prompt_preview = (prompt or "")[:50]

        feedback_data = {
            "timestamp": datetime.now().isoformat(),
            "prompt": prompt,
            "audio1_path": audio1_path,
            "audio2_path": audio2_path,
            "preference": preference,
            "additional_comment": comment or "",
            "generation_params": {
                "duration": duration,
                "cfg_strength": cfg_strength,
                "num_steps": num_steps,
                "variant": variant,
            },
        }

        # JSONL: one record per line, appended so earlier feedback is kept.
        with open(FEEDBACK_FILE, "a", encoding="utf-8") as f:
            f.write(json.dumps(feedback_data, ensure_ascii=False) + "\n")

        log.info(f"✅ 反馈已保存: {preference} - {prompt_preview}...")
        print(f"✅ 用户反馈已保存到: {FEEDBACK_FILE}")

    except Exception as e:
        # Deliberate best-effort: report the problem but never raise.
        log.error(f"保存反馈时出错: {e}")
        print(f"❌ 保存反馈时出错: {e}")
|
| 137 |
-
|
| 138 |
|
| 139 |
@spaces.GPU(duration=60)
|
| 140 |
@torch.inference_mode()
|
|
@@ -227,31 +194,13 @@ gr_interface = gr.Interface(
|
|
| 227 |
fn=generate_audio_gradio,
|
| 228 |
inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
|
| 229 |
outputs=[
|
| 230 |
-
gr.Audio(label="🎵 Audio Sample 1"),
|
| 231 |
-
gr.Audio(label="🎵 Audio Sample 2"),
|
| 232 |
gr.Textbox(label="Prompt Used", interactive=False)
|
| 233 |
],
|
| 234 |
-
additional_inputs=[
|
| 235 |
-
gr.Radio(
|
| 236 |
-
choices=[
|
| 237 |
-
("🎵 Audio 1 更好", "audio1"),
|
| 238 |
-
("🎵 Audio 2 更好", "audio2"),
|
| 239 |
-
("😊 两者都很好", "equal"),
|
| 240 |
-
("😔 两者都不好", "both_bad")
|
| 241 |
-
],
|
| 242 |
-
label="🤔 请选择您更喜欢的音频:",
|
| 243 |
-
value=None
|
| 244 |
-
),
|
| 245 |
-
gr.Textbox(
|
| 246 |
-
label="💭 评论 (可选)",
|
| 247 |
-
placeholder="您对音频质量的具体反馈...",
|
| 248 |
-
lines=2
|
| 249 |
-
)
|
| 250 |
-
],
|
| 251 |
title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
|
| 252 |
-
description="🎯 **RLHF数据收集**: 现在生成2
|
| 253 |
-
flagging_mode="
|
| 254 |
-
flagging_callback=lambda *args: save_preference_feedback_from_flag(*args),
|
| 255 |
examples=[
|
| 256 |
["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
|
| 257 |
["Melodic human whistling harmonizing with natural birdsong", 10, 3, 1, "meanaudio_s_full"],
|
|
@@ -266,13 +215,82 @@ gr_interface = gr.Interface(
|
|
| 266 |
['doorbell ding once followed by footsteps gradually getting louder and a door is opened ', 10, 3, 1, "meanaudio_s_full"],
|
| 267 |
["A fork scrapes a plate, water drips slowly into a sink, and the faint hum of a refrigerator lingers in the background", 10, 3, 1, "meanaudio_s_full"]
|
| 268 |
],
|
| 269 |
-
cache_examples="lazy",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
if __name__ == "__main__":
|
| 273 |
ensure_models_downloaded()
|
| 274 |
load_model_cache()
|
| 275 |
-
gr_interface.queue(15).launch()
|
| 276 |
|
| 277 |
# theme = gr.themes.Soft(
|
| 278 |
# primary_hue="blue",
|
|
|
|
| 102 |
log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
|
| 103 |
return f"✅ Thanks for your feedback, preference recorded: {preference}"
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
@spaces.GPU(duration=60)
|
| 107 |
@torch.inference_mode()
|
|
|
|
| 194 |
fn=generate_audio_gradio,
|
| 195 |
inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
|
| 196 |
outputs=[
|
| 197 |
+
gr.Audio(label="🎵 Audio Sample 1", type="filepath"),
|
| 198 |
+
gr.Audio(label="🎵 Audio Sample 2", type="filepath"),
|
| 199 |
gr.Textbox(label="Prompt Used", interactive=False)
|
| 200 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
|
| 202 |
+
description="🎯 **RLHF数据收集**: 现在生成2个音频样本!生成后请在下方选择偏好并提交。",
|
| 203 |
+
flagging_mode="never",
|
|
|
|
| 204 |
examples=[
|
| 205 |
["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
|
| 206 |
["Melodic human whistling harmonizing with natural birdsong", 10, 3, 1, "meanaudio_s_full"],
|
|
|
|
| 215 |
['doorbell ding once followed by footsteps gradually getting louder and a door is opened ', 10, 3, 1, "meanaudio_s_full"],
|
| 216 |
["A fork scrapes a plate, water drips slowly into a sink, and the faint hum of a refrigerator lingers in the background", 10, 3, 1, "meanaudio_s_full"]
|
| 217 |
],
|
| 218 |
+
cache_examples="lazy",
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# ==== Preference collection UI (RLHF) ====
|
| 222 |
+
|
| 223 |
+
# 允许用户在两段音频之间选择偏好,并补充备注
|
| 224 |
+
with gr.Blocks() as pref_block:
|
| 225 |
+
gr.Markdown("## 🧠 RLHF 偏好标注")
|
| 226 |
+
gr.Markdown("生成完成后,请在下方选择您更喜欢的音频(或都不好/差不多),并可附加简短备注。点“提交偏好”即可写入 `./rlhf/user_preferences.jsonl`。")
|
| 227 |
+
|
| 228 |
+
# 这里复用上面 Interface 的输出:我们需要拿到两段音频的文件路径与使用的 prompt
|
| 229 |
+
# 为了连接这两个“界面”,再放一组可粘连的输入组件:
|
| 230 |
+
with gr.Row():
|
| 231 |
+
gen_audio1_path = gr.Textbox(label="Audio 1 路径(自动填充)", interactive=False)
|
| 232 |
+
gen_audio2_path = gr.Textbox(label="Audio 2 路径(自动填充)", interactive=False)
|
| 233 |
+
prompt_used = gr.Textbox(label="Prompt(自动填充)", interactive=False)
|
| 234 |
+
|
| 235 |
+
# 偏好选项与备注
|
| 236 |
+
pref_choice = gr.Radio(
|
| 237 |
+
["audio1", "audio2", "equal", "both_bad"],
|
| 238 |
+
value="audio1",
|
| 239 |
+
label="你更偏好哪个?",
|
| 240 |
+
info="equal=差不多; both_bad=都不好"
|
| 241 |
+
)
|
| 242 |
+
pref_comment = gr.Textbox(label="可选备注(例如:哪一段更贴合描述、是否有噪声/破音等)", lines=2)
|
| 243 |
+
|
| 244 |
+
submit_btn = gr.Button("✅ 提交偏好")
|
| 245 |
+
submit_status = gr.Markdown()
|
| 246 |
+
|
| 247 |
+
# 小工具:读取当前标注条目数
|
| 248 |
+
def _count_feedback():
    """Return how many lines FEEDBACK_FILE contains, or 0 if it is missing."""
    try:
        feedback_file = open(FEEDBACK_FILE, "r", encoding="utf-8")
    except FileNotFoundError:
        # No feedback has been written yet.
        return 0
    with feedback_file:
        total = 0
        for _line in feedback_file:
            total += 1
    return total
|
| 254 |
+
|
| 255 |
+
refresh_btn = gr.Button("📈 刷新统计")
|
| 256 |
+
count_box = gr.Markdown()
|
| 257 |
+
|
| 258 |
+
def submit_preference_ui(a1, a2, p, pref, cmt):
    """Validate the two audio values, then store the preference.

    Returns a reminder message when either audio value is empty; otherwise
    returns whatever save_preference_feedback returns.
    """
    both_present = bool(a1) and bool(a2)
    if not both_present:
        return "❗请先在上面的生成器里生成两段音频。"
    # Delegate the actual write to the shared feedback helper.
    return save_preference_feedback(p, a1, a2, pref, cmt)
|
| 264 |
+
|
| 265 |
+
def refresh_count_ui():
    """Build the Markdown status line showing the _count_feedback() total."""
    total = _count_feedback()
    status = f"当前已收集 **{total}** 条偏好样本。"
    return status
|
| 268 |
+
|
| 269 |
+
submit_btn.click(
|
| 270 |
+
fn=submit_preference_ui,
|
| 271 |
+
inputs=[gen_audio1_path, gen_audio2_path, prompt_used, pref_choice, pref_comment],
|
| 272 |
+
outputs=submit_status
|
| 273 |
)
|
| 274 |
+
refresh_btn.click(fn=refresh_count_ui, outputs=count_box)
|
| 275 |
+
|
| 276 |
+
# —— 把 Interface 的输出“联动”到偏好区:当用户生成完成后,自动把路径和 prompt 填入偏好区输入框 ——
|
| 277 |
+
def _passthrough(a1, a2, p):
|
| 278 |
+
# 直接把接口输出透传给下方偏好区
|
| 279 |
+
return a1, a2, p
|
| 280 |
+
|
| 281 |
+
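# Use an event to connect the Interface outputs to the three pref_block textboxes.
# NOTE(review): only gr_interface is launched below, so pref_block may never be rendered;
# also confirm that gr.Interface really exposes `.submit` and `.outputs` — TODO verify against the Gradio docs.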
|
| 282 |
+
gr_interface.submit(
|
| 283 |
+
fn=_passthrough,
|
| 284 |
+
inputs=gr_interface.outputs, # [Audio1(filepath), Audio2(filepath), PromptUsed]
|
| 285 |
+
outputs=[gen_audio1_path, gen_audio2_path, prompt_used],
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
|
| 290 |
if __name__ == "__main__":
|
| 291 |
ensure_models_downloaded()
|
| 292 |
load_model_cache()
|
| 293 |
+
gr_interface.queue(15).launch(share=False, show_api=False)
|
| 294 |
|
| 295 |
# theme = gr.themes.Soft(
|
| 296 |
# primary_hue="blue",
|