DreamRenderer / app.py
HBDing's picture
优化app.py文件,移除演示模式相关代码,简化图像生成逻辑,确保在GPU模式下正常初始化DreamRenderer管道。同时更新README.md中的sdk版本至5.31.0,并添加硬件配置说明。
007662f
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageDraw
import json
import warnings
from typing import Optional
from dream_renderer import DreamRendererPipeline
# Flag read by create_interface() to choose between the demo-mode and the
# regular titles, instructions and button caption.
DEMO_MODE = False
# Suppress every Python warning for the lifetime of the process.
warnings.filterwarnings("ignore")
# Module-level state shared by the Gradio callbacks:
#   pipeline          - set by initialize_pipeline(), read by generate_image_with_bbox()
#   current_bbox_data - replaced by update_bbox_data(), read by generate_image_with_bbox()
pipeline = None
current_bbox_data = []
def create_demo_image(prompt: str, bbox_data: list, width: int = 512, height: int = 512):
    """Draw a placeholder image showing the prompt and the bounding boxes.

    No model is involved: the function paints a light-blue canvas, writes the
    (truncated) prompt and the number of boxes, then outlines every box.

    Args:
        prompt: Main prompt; only its first 50 characters are drawn.
        bbox_data: Sequence of dicts with 'x', 'y', 'width' and 'height'
            entries that are multiplied by the image size, plus an optional
            'label'.
        width: Width of the produced image in pixels.
        height: Height of the produced image in pixels.

    Returns:
        The PIL image. If anything goes wrong while drawing, the error text is
        written onto the image instead of raising.
    """
    # Outline colours, reused cyclically when there are more boxes than colours.
    palette = ('red', 'green', 'blue', 'yellow', 'purple', 'orange')

    image = Image.new('RGB', (width, height), color='lightblue')
    draw = ImageDraw.Draw(image)
    try:
        draw.text((10, 10), f"演示模式: {prompt[:50]}", fill='darkblue')
        draw.text((10, 30), f"边界框数量: {len(bbox_data)}", fill='darkblue')
        for i, bbox in enumerate(bbox_data):
            # Scale the box fields by the image size to get pixel coordinates.
            x = int(bbox['x'] * width)
            y = int(bbox['y'] * height)
            w = int(bbox['width'] * width)
            h = int(bbox['height'] * height)
            bbox_color = palette[i % len(palette)]
            # An empty label falls back to the default name, like the canvas does.
            caption = bbox.get('label') or f'区域{i+1}'
            draw.rectangle([x, y, x + w, y + h], outline=bbox_color, width=2)
            draw.text((x + 5, y + 5), caption, fill=bbox_color)
    except Exception as e:
        # Best effort: a malformed box must not prevent returning an image.
        draw.text((10, 50), f"绘制错误: {str(e)}", fill='red')
    return image
@spaces.GPU
def initialize_pipeline():
    """Create the DreamRenderer pipeline once and load its model.

    The module-level ``pipeline`` is only assigned after ``load_model()`` has
    returned. If construction or loading raises, ``pipeline`` stays ``None``
    so that a later call retries instead of reporting "already initialised"
    for an object whose loading never completed.

    Returns:
        A status message for the UI: success, loaded-with-warning (when
        ``load_model()`` returns a falsy value), already initialised, or the
        failure text of the exception that occurred.
    """
    global pipeline
    if pipeline is not None:
        return "✅ DreamRenderer管道已经初始化完成!"
    try:
        candidate = DreamRendererPipeline()
        # Load the model up front so the first generation does not pay for it.
        success = candidate.load_model()
    except Exception as e:
        return f"❌ 初始化失败: {str(e)}"
    # Publish the pipeline only now that load_model() has returned.
    # NOTE(review): what a falsy load_model() result implies is defined in
    # dream_renderer; the pipeline is kept in that case, as before.
    pipeline = candidate
    if success:
        return "✅ DreamRenderer管道已成功初始化并加载模型!"
    return "⚠️ DreamRenderer管道已初始化,但模型加载失败。将使用演示模式。"
@spaces.GPU
def generate_image_with_bbox(prompt: str, negative_prompt: str,
                             num_inference_steps: int, guidance_scale: float,
                             width: int, height: int, seed: int, use_seed: bool):
    """Generate an image with the stored bounding boxes as layout guidance.

    The boxes are taken from the module-level ``current_bbox_data``; the seed
    is forwarded only when ``use_seed`` is true.

    Returns:
        ``(image, info_text)`` on success, or ``(None, error_text)`` when the
        pipeline is missing, the prompt is blank, or generation raises.
    """
    global pipeline, current_bbox_data

    # Guard clauses: nothing can be generated without a pipeline and a prompt.
    if pipeline is None:
        return None, "❌ 请先初始化DreamRenderer管道!"
    if not prompt.strip():
        return None, "❌ 请输入提示词!"

    try:
        # None tells the pipeline that no fixed seed was requested.
        actual_seed = None
        if use_seed:
            actual_seed = seed

        image = pipeline.generate_image(
            prompt=prompt,
            bbox_data=current_bbox_data,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            seed=actual_seed
        )

        # Every report line ends with a newline except the optional seed line.
        report_lines = [
            "✅ 图像生成成功!",
            f"🔸 使用边界框: {len(current_bbox_data)}个",
            f"🔸 推理步数: {num_inference_steps}",
            f"🔸 引导强度: {guidance_scale}",
            f"🔸 图像尺寸: {width}×{height}",
        ]
        info = "\n".join(report_lines) + "\n"
        if actual_seed is not None:
            info += f"🔸 随机种子: {actual_seed}"
        return image, info
    except Exception as e:
        return None, f"❌ 生成图像时出错: {str(e)}"
def load_bbox_component():
    """Return the HTML markup of the bounding-box drawing component.

    The markup is read from ``bbox_component.html`` (UTF-8) in the current
    working directory. When that file does not exist, a small placeholder
    notice is returned instead.
    """
    # Placeholder shown when the component file is missing.
    fallback_markup = """
<div style="text-align: center; padding: 20px; border: 2px dashed #ccc; border-radius: 8px;">
<p>边界框组件加载失败</p>
<p>请检查 bbox_component.html 文件是否存在</p>
</div>
"""
    try:
        with open('bbox_component.html', 'r', encoding='utf-8') as handle:
            return handle.read()
    except FileNotFoundError:
        return fallback_markup
def update_bbox_data(bbox_json: str):
    """Parse the bounding-box JSON and rebuild the summary text and editor HTML.

    The parsed list replaces the module-level ``current_bbox_data``; on empty
    input or on any error that global is reset to an empty list.

    ``label`` and ``prompt`` are user-entered text, so they are HTML-escaped
    before being placed into the editor markup. Without that, a quote or a
    ``</textarea>`` in the text would break out of the attribute/element.

    Args:
        bbox_json: JSON array of box objects. Each object may carry 'x', 'y',
            'width', 'height', 'label' and 'prompt'; missing entries default
            to 0, a generated label and an empty prompt respectively.

    Returns:
        ``(info_text, editor_html)``. ``editor_html`` is an empty string when
        there are no boxes or when the input could not be processed.
    """
    global current_bbox_data
    # Local import keeps this block self-contained; html is standard library.
    import html

    empty_result = ("📦 暂无边界框数据\n\n💡 提示:在画布上拖拽鼠标绘制边界框", "")
    try:
        if not bbox_json or not bbox_json.strip():
            current_bbox_data = []
            return empty_result

        bbox_data = json.loads(bbox_json)
        if not bbox_data:
            current_bbox_data = []
            return empty_result
        # Publish the boxes for generate_image_with_bbox().
        current_bbox_data = bbox_data

        info_lines = [
            f"📦 边界框数据 ({len(bbox_data)} 个)",
            "=" * 40,
            ""
        ]
        edit_html_lines = [
            '<div style="max-height: 400px; overflow-y: auto; padding: 10px; border: 1px solid #ddd; border-radius: 8px; background: #f9f9f9;">',
            '<h4 style="color: #333; margin-top: 0;">🎯 边界框描述编辑</h4>'
        ]

        for i, bbox in enumerate(bbox_data, 1):
            x = bbox.get('x', 0)
            y = bbox.get('y', 0)
            width = bbox.get('width', 0)
            height = bbox.get('height', 0)
            label = bbox.get('label', f'区域{i}')
            prompt = bbox.get('prompt', '')

            # The info box is plain text, so the raw values are shown as-is.
            info_lines.extend([
                f"🎯 边界框 {i}:",
                f" 📍 位置: ({x:.3f}, {y:.3f})",
                f" 📏 大小: {width:.3f} × {height:.3f}",
                f" 🏷️ 标签: {label}",
                f" 💬 描述: {prompt or '(请在下方输入描述)'}",
                ""
            ])

            # Escaped copies for the HTML editor (attribute value and element text).
            safe_label = html.escape(str(label), quote=True)
            safe_prompt = html.escape(str(prompt), quote=True)
            idx = i - 1  # zero-based index used by the JavaScript callbacks
            color = f"hsl({idx * 60}, 70%, 50%)"
            edit_html_lines.extend([
                f'<div style="margin: 15px 0; padding: 15px; border-left: 4px solid {color}; background: white; border-radius: 8px;">',
                ' <div style="display: flex; align-items: center; margin-bottom: 10px;">',
                f' <div style="width: 20px; height: 20px; background: {color}; border-radius: 4px; margin-right: 10px;"></div>',
                f' <strong style="color: #333;">边界框 {i} - {safe_label}</strong>',
                f' <span style="margin-left: auto; font-size: 0.9em; color: #666;">({x:.2f}, {y:.2f}) {width:.2f}×{height:.2f}</span>',
                ' </div>',
                ' <div style="margin-bottom: 8px;">',
                ' <label style="display: block; font-weight: bold; color: #555; margin-bottom: 5px;">🏷️ 区域标签:</label>',
                f' <input type="text" id="bbox_label_{idx}" value="{safe_label}" placeholder="为这个区域命名..." ',
                ' style="width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 14px;" ',
                f' onchange="updateBboxField({idx}, \'label\', this.value)">',
                ' </div>',
                ' <div style="margin-bottom: 8px;">',
                ' <label style="display: block; font-weight: bold; color: #555; margin-bottom: 5px;">💬 详细描述:</label>',
                f' <textarea id="bbox_prompt_{idx}" placeholder="描述这个区域应该生成什么内容..." ',
                ' style="width: 100%; height: 80px; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 14px; resize: vertical;" ',
                f' onchange="updateBboxField({idx}, \'prompt\', this.value)">{safe_prompt}</textarea>',
                ' </div>',
                ' <div style="text-align: right;">',
                f' <button onclick="deleteBbox({idx})" style="background: #ff4757; color: white; border: none; padding: 6px 12px; border-radius: 4px; cursor: pointer; font-size: 12px;">',
                ' 🗑️ 删除此框',
                ' </button>',
                ' </div>',
                '</div>'
            ])

        # Footer with the total count and the "clear all" button.
        edit_html_lines.extend([
            '<div style="margin-top: 20px; padding: 15px; background: #e8f5e8; border-radius: 8px;">',
            ' <div style="display: flex; justify-content: space-between; align-items: center;">',
            ' <div>',
            f' <strong style="color: #2d5a2d;">✅ 共 {len(bbox_data)} 个边界框</strong>',
            ' <br><small style="color: #5a5a5a;">修改描述后会自动保存</small>',
            ' </div>',
            ' <button onclick="clearAllBboxes()" style="background: #ff6b6b; color: white; border: none; padding: 8px 16px; border-radius: 6px; cursor: pointer;">',
            ' 🗑️ 清空所有',
            ' </button>',
            ' </div>',
            '</div>',
            '</div>'
        ])
        info_lines.extend([
            "💡 使用说明:",
            "• 在画布上拖拽绘制新的边界框",
            "• 在右侧为每个框输入具体描述",
            "• 每个框可以有不同的生成内容",
            "• 描述越详细,生成效果越好"
        ])
        print(f"DEBUG: 边界框数据已更新: {len(current_bbox_data)}个")  # debug trace
        return "\n".join(info_lines), "\n".join(edit_html_lines)
    except json.JSONDecodeError:
        current_bbox_data = []
        return f"❌ 边界框数据格式错误\n\n原始数据: {bbox_json[:200]}...", ""
    except Exception as e:
        # UI boundary: payloads of an unexpected shape (non-list JSON,
        # non-numeric coordinates, ...) end up here and are reported as text.
        current_bbox_data = []
        return f"❌ 处理边界框数据时出错: {str(e)}", ""
def create_interface():
    """Build the Gradio Blocks UI and return it (without launching it).

    The page consists of a title, a usage accordion, a two-column row (left:
    initialisation and bounding-box canvas; right: box editor, prompts,
    generation parameters, generate button and results) and a tips accordion.
    The three callbacks initialize_pipeline, update_bbox_data and
    generate_image_with_bbox are bound here, and a client-side script is
    attached through ``demo.load`` to drive the drawing canvas.

    Returns:
        The ``gr.Blocks`` instance created by this function.
    """
    # Custom CSS
    # NOTE(review): no component created below is given the .generate-btn or
    # .init-btn class (no elem_classes argument is passed anywhere here).
    css = """
.main-container {
max-width: 1400px;
margin: 0 auto;
}
.bbox-container {
border: 2px solid #e1e5e9;
border-radius: 12px;
padding: 20px;
margin: 15px 0;
background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
}
.generate-btn {
background: linear-gradient(45deg, #FF6B6B, #4ECDC4);
border: none;
border-radius: 25px;
padding: 15px 35px;
color: white;
font-weight: bold;
font-size: 18px;
box-shadow: 0 4px 15px rgba(0,0,0,0.2);
transition: all 0.3s ease;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(0,0,0,0.3);
}
.init-btn {
background: linear-gradient(45deg, #667eea, #764ba2);
border: none;
border-radius: 20px;
color: white;
font-weight: bold;
padding: 12px 25px;
}
"""
    with gr.Blocks(css=css, title="DreamRenderer - Multi-Instance Control", theme=gr.themes.Soft()) as demo:
        # Show a different title depending on the mode
        if DEMO_MODE:
            gr.HTML("""
<div style="text-align: center; padding: 20px;">
<h1 style="background: linear-gradient(45deg, #FF6B6B, #4ECDC4); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 3em; margin-bottom: 10px;">
🎨 DreamRenderer (演示模式)
</h1>
<h2 style="color: #666; margin-bottom: 20px;">Multi-Instance Attribute Control</h2>
<div style="background: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 15px; margin: 20px auto; max-width: 800px;">
<p style="margin: 0; color: #856404; font-size: 1.1em;">
⚠️ <strong>当前运行在演示模式下</strong><br>
由于缺少实际的AI模型,系统将生成简单的演示图像来展示界面功能。<br>
您仍然可以测试边界框绘制和参数设置功能。
</p>
</div>
</div>
""")
        else:
            gr.HTML("""
<div style="text-align: center; padding: 20px;">
<h1 style="background: linear-gradient(45deg, #FF6B6B, #4ECDC4); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 3em; margin-bottom: 10px;">
🎨 DreamRenderer
</h1>
<h2 style="color: #666; margin-bottom: 20px;">Multi-Instance Attribute Control</h2>
<p style="font-size: 1.2em; color: #888; max-width: 800px; margin: 0 auto;">
基于ZeroGPU的高质量多实例属性控制文本到图像生成工具
</p>
</div>
""")
        # Usage instructions
        with gr.Accordion("📖 使用说明", open=False):
            if DEMO_MODE:
                gr.Markdown("""
### 🚀 演示模式说明:
1. **功能测试**: 点击"初始化"按钮启动演示模式
2. **绘制区域**: 在画布上拖拽鼠标绘制边界框
3. **添加描述**: 为每个边界框输入描述文本
4. **设置参数**: 调整生成参数(用于演示)
5. **生成图像**: 输入主提示词并点击生成演示图像
### ⚠️ 演示模式限制:
- 🎯 **界面功能**: 所有界面功能都可以正常使用
- 🖼️ **图像生成**: 生成的是简单的演示图像,非AI生成
- 📦 **边界框**: 边界框绘制和编辑功能完全正常
- 🔧 **参数调节**: 参数设置功能正常,但不影响实际生成
### 📌 完整功能需要:
- 安装完整的AI模型(dream_renderer模块)
- 配置ZeroGPU环境
""")
            else:
                gr.Markdown("""
### 🚀 快速开始:
1. **初始化**: 点击"初始化DreamRenderer"按钮加载模型
2. **绘制区域**: 在画布上拖拽鼠标绘制边界框
3. **添加描述**: 为每个边界框输入描述文本
4. **设置参数**: 调整生成参数(可选)
5. **生成图像**: 输入主提示词并点击生成
### ✨ 功能特点:
- 🎯 **精确控制**: 通过边界框精确控制每个实例的位置和属性
- 🚀 **ZeroGPU加速**: 利用Hugging Face的ZeroGPU实现快速推理
- 🎨 **高质量生成**: 基于FLUX模型的高质量图像生成
- 🔧 **灵活参数**: 丰富的参数调节选项
""")
        with gr.Row():
            # Left column: bounding-box drawing and controls
            with gr.Column(scale=1):
                # Initialisation section
                with gr.Group():
                    gr.Markdown("### 🚀 模型初始化")
                    if DEMO_MODE:
                        init_btn = gr.Button("🚀 启动演示模式", variant="primary")
                    else:
                        init_btn = gr.Button("🚀 初始化DreamRenderer", variant="primary")
                    init_status = gr.Textbox(label="初始化状态", interactive=False, lines=2)
                # Bounding-box drawing area
                with gr.Group():
                    gr.Markdown("### 📦 边界框绘制")
                    gr.HTML("""
<div style="background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%); padding: 10px; border-radius: 8px; margin: 10px 0;">
<p style="margin: 0; color: #1976d2;"><strong>步骤1:</strong> 在画布上拖拽鼠标绘制边界框</p>
<p style="margin: 5px 0 0 0; color: #1976d2;"><strong>步骤2:</strong> 在右侧为每个框输入详细描述</p>
</div>
""")
                    # Markup comes from bbox_component.html (or a placeholder);
                    # the script at the bottom looks up an element with id
                    # 'bboxCanvas', presumably provided by that file.
                    bbox_component = gr.HTML(load_bbox_component())
                    # Hidden textbox that receives the bounding-box data; the
                    # client-side script writes JSON into '#bbox_data textarea'.
                    bbox_data_input = gr.Textbox(visible=False, elem_id="bbox_data")
                    bbox_info = gr.Textbox(label="📦 边界框信息", interactive=False, lines=8, placeholder="边界框信息将在这里显示...")
            # Right column: bounding-box editing and generation parameters
            with gr.Column(scale=1):
                # Bounding-box editing area
                with gr.Group():
                    gr.Markdown("### ✏️ 边界框描述编辑")
                    bbox_editor = gr.HTML(
                        value="<div style='text-align: center; padding: 40px; color: #666;'>绘制边界框后,编辑界面将出现在这里</div>",
                        elem_id="bbox_editor"
                    )
                # Prompt settings
                with gr.Group():
                    gr.Markdown("### 📝 提示词设置")
                    prompt = gr.Textbox(
                        label="主提示词",
                        placeholder="描述你想要生成的整体场景...",
                        lines=3,
                        value="a beautiful landscape"
                    )
                    negative_prompt = gr.Textbox(
                        label="负向提示词",
                        placeholder="描述你不想看到的内容...",
                        lines=2,
                        value="blurry, low quality, distorted"
                    )
                # Generation parameters
                with gr.Group():
                    gr.Markdown("### ⚙️ 生成参数")
                    with gr.Row():
                        num_steps = gr.Slider(
                            minimum=1, maximum=100, value=20, step=1,
                            label="推理步数",
                            info="更多步数通常能获得更好的质量"
                        )
                        guidance_scale = gr.Slider(
                            minimum=1.0, maximum=30.0, value=7.5, step=0.5,
                            label="引导强度",
                            info="控制对提示词的遵循程度"
                        )
                    with gr.Row():
                        width = gr.Slider(
                            minimum=256, maximum=1024, value=512, step=64,
                            label="宽度"
                        )
                        height = gr.Slider(
                            minimum=256, maximum=1024, value=512, step=64,
                            label="高度"
                        )
                    with gr.Row():
                        use_seed = gr.Checkbox(label="使用固定种子", value=False)
                        seed = gr.Number(label="随机种子", value=42, precision=0)
                # Generate button
                generate_btn = gr.Button(
                    "🎨 生成图像",
                    variant="primary",
                    size="lg"
                )
                # Result display
                with gr.Group():
                    gr.Markdown("### 🖼️ 生成结果")
                    output_image = gr.Image(label="生成的图像", height=500, show_label=False)
                    generation_info = gr.Textbox(label="生成信息", interactive=False, lines=6)
        # Examples and further options
        with gr.Accordion("🎯 示例和技巧", open=False):
            gr.Markdown("""
### 📌 提示词示例:
- **风景场景**: "a serene mountain landscape with a lake, golden hour lighting"
- **城市场景**: "modern city skyline at sunset, futuristic architecture"
- **人物场景**: "a group of people in a park, casual clothing, natural lighting"
### 🎨 使用技巧:
1. **边界框大小**: 合适的边界框大小有助于更好的控制效果
2. **描述精确性**: 为每个区域提供具体而精确的描述
3. **参数调节**: 较高的引导强度可以提高对提示词的遵循度
4. **种子控制**: 使用固定种子可以获得可重复的结果
""")
        # Event bindings
        init_btn.click(
            fn=initialize_pipeline,
            outputs=init_status,
            show_progress=True
        )
        # Changes of the hidden textbox refresh the info text and the editor HTML.
        bbox_data_input.change(
            fn=update_bbox_data,
            inputs=bbox_data_input,
            outputs=[bbox_info, bbox_editor]
        )
        # The input order must match generate_image_with_bbox's parameter order.
        generate_btn.click(
            fn=generate_image_with_bbox,
            inputs=[prompt, negative_prompt, num_steps, guidance_scale, width, height, seed, use_seed],
            outputs=[output_image, generation_info],
            show_progress=True
        )
        # JavaScript that handles the bounding-box data exchange. After a one
        # second delay it attaches mouse handlers to the 'bboxCanvas' element,
        # keeps the drawn boxes in a client-side array and writes them, scaled
        # by the canvas size, as JSON into the hidden '#bbox_data' textarea.
        # It also defines window.updateBboxField / deleteBbox / clearAllBboxes,
        # the handlers referenced by the editor HTML from update_bbox_data().
        demo.load(None, None, None, js="""
function() {
console.log('DreamRenderer界面已加载');
// 给DOM一些时间加载
setTimeout(function() {
try {
// 全局变量
let isDrawing = false;
let startX, startY, currentRect;
let bboxes = [];
const canvas = document.getElementById('bboxCanvas');
if (!canvas) {
console.error('画布元素未找到');
return;
}
const ctx = canvas.getContext('2d');
// 清除之前的监听器并添加新的
canvas.replaceWith(canvas.cloneNode(true));
const newCanvas = document.getElementById('bboxCanvas');
const newCtx = newCanvas.getContext('2d');
// 重新设置样式
newCanvas.style.display = 'block';
newCanvas.style.border = '2px solid #4ECDC4';
newCanvas.style.backgroundColor = 'white';
newCanvas.style.cursor = 'crosshair';
newCanvas.style.borderRadius = '8px';
// 边界框编辑函数
window.updateBboxField = function(index, field, value) {
if (index >= 0 && index < bboxes.length) {
bboxes[index][field] = value;
console.log(`更新边界框 ${index} 的 ${field}:`, value);
updateBboxData();
}
};
window.deleteBbox = function(index) {
if (index >= 0 && index < bboxes.length) {
bboxes.splice(index, 1);
console.log(`删除边界框 ${index}`);
redrawCanvas();
updateBboxData();
}
};
window.clearAllBboxes = function() {
bboxes = [];
console.log('清空所有边界框');
redrawCanvas();
updateBboxData();
};
// 重绘画布
function redrawCanvas() {
newCtx.clearRect(0, 0, newCanvas.width, newCanvas.height);
bboxes.forEach((bbox, index) => {
newCtx.strokeStyle = `hsl(${index * 60}, 70%, 50%)`;
newCtx.lineWidth = 2;
newCtx.strokeRect(bbox.x, bbox.y, bbox.width, bbox.height);
// 绘制标签
newCtx.fillStyle = `hsl(${index * 60}, 70%, 50%)`;
newCtx.font = '12px Arial';
newCtx.fillText(bbox.label || `区域${index + 1}`, bbox.x + 5, bbox.y - 5);
});
}
// 更新边界框数据
function updateBboxData() {
const relativeBboxes = bboxes.map(b => ({
x: b.x / newCanvas.width,
y: b.y / newCanvas.height,
width: b.width / newCanvas.width,
height: b.height / newCanvas.height,
label: b.label || '',
prompt: b.prompt || ''
}));
const dataString = JSON.stringify(relativeBboxes);
console.log('📤 更新数据:', relativeBboxes.length, '个边界框');
const textarea = document.querySelector('#bbox_data textarea');
if (textarea) {
textarea.value = dataString;
textarea.dispatchEvent(new Event('input', { bubbles: true }));
}
}
// 添加绘制事件监听器
newCanvas.addEventListener('mousedown', function(e) {
isDrawing = true;
startX = e.offsetX;
startY = e.offsetY;
console.log('🎯 开始绘制:', startX, startY);
});
newCanvas.addEventListener('mousemove', function(e) {
if (!isDrawing) return;
const currentX = e.offsetX;
const currentY = e.offsetY;
// 清除画布并重绘所有边界框
redrawCanvas();
// 绘制当前正在绘制的框
newCtx.strokeStyle = '#007bff';
newCtx.lineWidth = 2;
newCtx.setLineDash([5, 5]);
const width = currentX - startX;
const height = currentY - startY;
newCtx.strokeRect(startX, startY, width, height);
newCtx.setLineDash([]);
});
newCanvas.addEventListener('mouseup', function(e) {
if (!isDrawing) return;
isDrawing = false;
const endX = e.offsetX;
const endY = e.offsetY;
const width = endX - startX;
const height = endY - startY;
// 只有当框足够大时才添加
if (Math.abs(width) > 10 && Math.abs(height) > 10) {
const bbox = {
x: Math.min(startX, endX),
y: Math.min(startY, endY),
width: Math.abs(width),
height: Math.abs(height),
label: `区域${bboxes.length + 1}`,
prompt: ''
};
bboxes.push(bbox);
console.log('✅ 添加边界框:', bbox);
redrawCanvas();
updateBboxData();
}
console.log('🎯 绘制结束,当前边界框数量:', bboxes.length);
});
console.log('🚀 DreamRenderer边界框功能已就绪!');
} catch (error) {
console.error('初始化边界框功能时出错:', error);
}
}, 1000); // 延迟1秒执行
}
""")
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it with errors surfaced.
    demo = create_interface()
    demo.launch(show_error=True)