Longxiang-ai committed on
Commit 0274afd · 0 Parent(s)

Initial commit: DreamRenderer with Zero GPU support

Files changed (6)
  1. .gitignore +49 -0
  2. README.md +57 -0
  3. app.py +565 -0
  4. bbox_component.html +353 -0
  5. dream_renderer.py +312 -0
  6. requirements.txt +10 -0
.gitignore ADDED
@@ -0,0 +1,49 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyTorch
+ *.pth
+ *.pt
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Logs
+ *.log
+
+ # Temporary files
+ *.tmp
+ *.temp
+
+ # Test files (not needed in the production upload)
+ test_*.py
+ *_test.py
README.md ADDED
@@ -0,0 +1,57 @@
+ ---
+ title: DreamRenderer
+ emoji: 🎨
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 4.44.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ hardware: zero-gpu-medium
+ ---
+
+ # DreamRenderer: Multi-Instance Attribute Control 🎨
+
+ A powerful Gradio implementation of DreamRenderer for precise multi-instance attribute control in text-to-image generation, powered by **Zero GPU** for fast inference.
+
+ ## ✨ Features
+
+ - 🖼️ **Interactive Bounding Box Drawing**: An intuitive canvas interface for drawing multiple regions with ease
+ - 🎯 **Multi-Instance Attribute Control**: Set distinct generated content for each region
+ - ⚡ **Zero GPU Acceleration**: Uses Hugging Face Zero GPU for fast inference
+ - 🚀 **FLUX Model Support**: Supports the latest FLUX diffusion models
+ - 🎨 **Real-time Preview**: Preview bounding boxes and generation parameters in real time
+
+ ## 🚀 Quick Start
+
+ 1. Drag the mouse on the canvas to draw bounding boxes
+ 2. Add a detailed description to each bounding box
+ 3. Set the global prompt and the generation parameters
+ 4. Click the generate button and enjoy creating with AI!
+
+ ## 🛠️ Technical Details
+
+ - **Frontend**: Gradio 4.44.0
+ - **Backend**: PyTorch + Diffusers
+ - **Model**: FLUX-based diffusion model
+ - **Acceleration**: Hugging Face Zero GPU
+ - **Memory**: Optimized for GPU memory efficiency
+
+ ## 💡 Usage Tips
+
+ - The more detailed the description, the better the result
+ - Different regions can be given completely different content
+ - Adjust the number of inference steps and the guidance scale to control generation quality
+ - Use the seed option to get reproducible results
+
+ ## 🎯 Perfect for
+
+ - Composing complex scenes
+ - Multi-object image generation
+ - Precise control over spatial layout
+ - Creative design and artwork
+
+ ---
+
+ **Note**: This app uses Zero GPU resources; the first load may take a few seconds while the model initializes.
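
The usage tips in the README mention inference steps, guidance scale, and a seed. As a rough illustration only, the sketch below groups those settings and checks them against the slider ranges that `app.py` in this commit defines (steps 1 to 100, guidance 1.0 to 30.0, width and height 256 to 1024 in steps of 64). The `GenerationParams` class and its `validate` method are hypothetical names and are not part of the commit.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class GenerationParams:
    """Hypothetical container for the generation settings named in the README."""
    num_inference_steps: int = 20
    guidance_scale: float = 7.5
    width: int = 512
    height: int = 512
    seed: Optional[int] = None  # None means "do not fix the seed"

    def validate(self) -> None:
        """Raise ValueError if a value falls outside the app.py slider ranges."""
        if not 1 <= self.num_inference_steps <= 100:
            raise ValueError("num_inference_steps must be in [1, 100]")
        if not 1.0 <= self.guidance_scale <= 30.0:
            raise ValueError("guidance_scale must be in [1.0, 30.0]")
        for name, value in (("width", self.width), ("height", self.height)):
            if not 256 <= value <= 1024 or value % 64 != 0:
                raise ValueError(f"{name} must be a multiple of 64 in [256, 1024]")


# A fixed seed corresponds to ticking the "use fixed seed" checkbox in the UI.
params = GenerationParams(seed=42)
params.validate()
```

Validating on the Python side is a design choice for callers that bypass the sliders; the UI itself already constrains these values.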
app.py ADDED
@@ -0,0 +1,565 @@
+ import gradio as gr
+ import spaces
+ import torch
+ import numpy as np
+ from PIL import Image, ImageDraw
+ import json
+ import base64
+ import io
+ from typing import List, Dict, Tuple, Optional
+ import warnings
+ from dream_renderer import DreamRendererPipeline
+
+ warnings.filterwarnings("ignore")
+
+ # Global state
+ pipeline = None
+ current_bbox_data = []
+
+ @spaces.GPU
+ def initialize_pipeline():
+     """Initialize the DreamRenderer pipeline."""
+     global pipeline
+     try:
+         if pipeline is None:
+             pipeline = DreamRendererPipeline()
+             # Preload the model to save time
+             success = pipeline.load_model()
+             if success:
+                 return "✅ DreamRenderer pipeline initialized and model loaded!"
+             else:
+                 return "⚠️ DreamRenderer pipeline initialized, but the model failed to load. Demo mode will be used."
+         else:
+             return "✅ DreamRenderer pipeline is already initialized!"
+     except Exception as e:
+         return f"❌ Initialization failed: {str(e)}"
+
+ def load_bbox_component():
+     """Load the bounding-box drawing component."""
+     try:
+         with open('bbox_component_fixed.html', 'r', encoding='utf-8') as f:
+             return f.read()
+     except FileNotFoundError:
+         # Fall back to the original version if the fixed one does not exist
+         with open('bbox_component.html', 'r', encoding='utf-8') as f:
+             return f.read()
+
+ def update_bbox_data(bbox_json: str):
+     """Update the displayed bounding-box data."""
+     global current_bbox_data
+
+     try:
+         if not bbox_json or bbox_json.strip() == "":
+             current_bbox_data = []
+             return "📦 No bounding boxes yet\n\n💡 Tip: drag the mouse on the canvas to draw a bounding box", ""
+
+         bbox_data = json.loads(bbox_json)
+         current_bbox_data = bbox_data  # Important: update the global variable
+
+         if not bbox_data:
+             current_bbox_data = []
+             return "📦 No bounding boxes yet\n\n💡 Tip: drag the mouse on the canvas to draw a bounding box", ""
+
+         info_lines = [
+             f"📦 Bounding boxes ({len(bbox_data)})",
+             "=" * 40,
+             ""
+         ]
+
+         # Build the HTML for the bounding-box editor
+         edit_html_lines = [
+             '<div style="max-height: 400px; overflow-y: auto; padding: 10px; border: 1px solid #ddd; border-radius: 8px; background: #f9f9f9;">',
+             '<h4 style="color: #333; margin-top: 0;">🎯 Edit bounding-box descriptions</h4>'
+         ]
+
+         for i, bbox in enumerate(bbox_data, 1):
+             x = bbox.get('x', 0)
+             y = bbox.get('y', 0)
+             width = bbox.get('width', 0)
+             height = bbox.get('height', 0)
+             label = bbox.get('label', f'Region {i}')
+             prompt = bbox.get('prompt', '')  # Existing prompt, if any
+
+             info_lines.extend([
+                 f"🎯 Bounding box {i}:",
+                 f"   📍 Position: ({x:.3f}, {y:.3f})",
+                 f"   📏 Size: {width:.3f} × {height:.3f}",
+                 f"   🏷️ Label: {label}",
+                 f"   💬 Description: {prompt or '(enter a description below)'}",
+                 ""
+             ])
+
+             # Build the editor block for this bounding box
+             color = f"hsl({(i-1) * 60}, 70%, 50%)"
+             edit_html_lines.extend([
+                 f'<div style="margin: 15px 0; padding: 15px; border-left: 4px solid {color}; background: white; border-radius: 8px;">',
+                 f'  <div style="display: flex; align-items: center; margin-bottom: 10px;">',
+                 f'    <div style="width: 20px; height: 20px; background: {color}; border-radius: 4px; margin-right: 10px;"></div>',
+                 f'    <strong style="color: #333;">Bounding box {i} - {label}</strong>',
+                 f'    <span style="margin-left: auto; font-size: 0.9em; color: #666;">({x:.2f}, {y:.2f}) {width:.2f}×{height:.2f}</span>',
+                 f'  </div>',
+                 f'  <div style="margin-bottom: 8px;">',
+                 f'    <label style="display: block; font-weight: bold; color: #555; margin-bottom: 5px;">🏷️ Region label:</label>',
+                 f'    <input type="text" id="bbox_label_{i-1}" value="{label}" placeholder="Name this region..." ',
+                 f'           style="width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 14px;" ',
+                 f'           onchange="updateBboxField({i-1}, \'label\', this.value)">',
+                 f'  </div>',
+                 f'  <div style="margin-bottom: 8px;">',
+                 f'    <label style="display: block; font-weight: bold; color: #555; margin-bottom: 5px;">💬 Detailed description:</label>',
+                 f'    <textarea id="bbox_prompt_{i-1}" placeholder="Describe what should be generated in this region..." ',
+                 f'              style="width: 100%; height: 80px; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 14px; resize: vertical;" ',
+                 f'              onchange="updateBboxField({i-1}, \'prompt\', this.value)">{prompt}</textarea>',
+                 f'  </div>',
+                 f'  <div style="text-align: right;">',
+                 f'    <button onclick="deleteBbox({i-1})" style="background: #ff4757; color: white; border: none; padding: 6px 12px; border-radius: 4px; cursor: pointer; font-size: 12px;">',
+                 f'      🗑️ Delete this box',
+                 f'    </button>',
+                 f'  </div>',
+                 f'</div>'
+             ])
+
+         edit_html_lines.extend([
+             '<div style="margin-top: 20px; padding: 15px; background: #e8f5e8; border-radius: 8px;">',
+             '  <div style="display: flex; justify-content: space-between; align-items: center;">',
+             '    <div>',
+             f'      <strong style="color: #2d5a2d;">✅ {len(bbox_data)} bounding boxes in total</strong>',
+             '      <br><small style="color: #5a5a5a;">Changes to descriptions are saved automatically</small>',
+             '    </div>',
+             '    <button onclick="clearAllBboxes()" style="background: #ff6b6b; color: white; border: none; padding: 8px 16px; border-radius: 6px; cursor: pointer;">',
+             '      🗑️ Clear all',
+             '    </button>',
+             '  </div>',
+             '</div>',
+             '</div>'
+         ])
+
+         info_lines.extend([
+             "💡 How to use:",
+             "• Drag on the canvas to draw a new bounding box",
+             "• Enter a specific description for each box on the right",
+             "• Each box can have different generated content",
+             "• The more detailed the description, the better the result"
+         ])
+
+         print(f"DEBUG: bounding-box data updated: {len(current_bbox_data)} boxes")  # Debug output
+         return "\n".join(info_lines), "\n".join(edit_html_lines)
+
+     except json.JSONDecodeError:
+         current_bbox_data = []
+         return f"❌ Malformed bounding-box data\n\nRaw data: {bbox_json[:200]}...", ""
+     except Exception as e:
+         current_bbox_data = []
+         return f"❌ Error while processing bounding-box data: {str(e)}", ""
+
+ @spaces.GPU
+ def generate_image_with_bbox(prompt: str, negative_prompt: str,
+                              num_inference_steps: int, guidance_scale: float,
+                              width: int, height: int, seed: int, use_seed: bool):
+     """Generate an image using the current bounding boxes."""
+     global pipeline, current_bbox_data
+
+     if pipeline is None:
+         return None, "❌ Please initialize the DreamRenderer pipeline first!"
+
+     if not prompt.strip():
+         return None, "❌ Please enter a prompt!"
+
+     try:
+         # Set the seed
+         actual_seed = seed if use_seed else None
+
+         # Generate the image
+         image = pipeline.generate_image(
+             prompt=prompt,
+             bbox_data=current_bbox_data,
+             negative_prompt=negative_prompt,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             width=width,
+             height=height,
+             seed=actual_seed
+         )
+
+         info = f"✅ Image generated successfully!\n"
+         info += f"🔸 Bounding boxes used: {len(current_bbox_data)}\n"
+         info += f"🔸 Inference steps: {num_inference_steps}\n"
+         info += f"🔸 Guidance scale: {guidance_scale}\n"
+         info += f"🔸 Image size: {width}×{height}\n"
+         if actual_seed is not None:
+             info += f"🔸 Random seed: {actual_seed}"
+
+         return image, info
+     except Exception as e:
+         return None, f"❌ Error while generating the image: {str(e)}"
+
+ def create_interface():
+     """Create the Gradio interface."""
+
+     # Custom CSS
+     css = """
+     .main-container {
+         max-width: 1400px;
+         margin: 0 auto;
+     }
+     .bbox-container {
+         border: 2px solid #e1e5e9;
+         border-radius: 12px;
+         padding: 20px;
+         margin: 15px 0;
+         background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
+     }
+     .generate-btn {
+         background: linear-gradient(45deg, #FF6B6B, #4ECDC4);
+         border: none;
+         border-radius: 25px;
+         padding: 15px 35px;
+         color: white;
+         font-weight: bold;
+         font-size: 18px;
+         box-shadow: 0 4px 15px rgba(0,0,0,0.2);
+         transition: all 0.3s ease;
+     }
+     .generate-btn:hover {
+         transform: translateY(-2px);
+         box-shadow: 0 6px 20px rgba(0,0,0,0.3);
+     }
+     .init-btn {
+         background: linear-gradient(45deg, #667eea, #764ba2);
+         border: none;
+         border-radius: 20px;
+         color: white;
+         font-weight: bold;
+         padding: 12px 25px;
+     }
+     .parameter-group {
+         background: #f8f9fa;
+         border-radius: 10px;
+         padding: 15px;
+         margin: 10px 0;
+     }
+     """
+
+     with gr.Blocks(css=css, title="DreamRenderer - Multi-Instance Control", theme=gr.themes.Soft()) as demo:
+         gr.HTML("""
+         <div style="text-align: center; padding: 20px;">
+             <h1 style="background: linear-gradient(45deg, #FF6B6B, #4ECDC4); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 3em; margin-bottom: 10px;">
+                 🎨 DreamRenderer
+             </h1>
+             <h2 style="color: #666; margin-bottom: 20px;">Multi-Instance Attribute Control</h2>
+             <p style="font-size: 1.2em; color: #888; max-width: 800px; margin: 0 auto;">
+                 A high-quality, ZeroGPU-based text-to-image tool with multi-instance attribute control
+             </p>
+         </div>
+         """)
+
+         # Instructions
+         with gr.Accordion("📖 Instructions", open=False):
+             gr.Markdown("""
+             ### 🚀 Quick start:
+             1. **Initialize**: Click the "Initialize DreamRenderer" button to load the model
+             2. **Draw regions**: Drag the mouse on the canvas to draw bounding boxes
+             3. **Add descriptions**: Enter a description for each bounding box
+             4. **Set parameters**: Adjust the generation parameters (optional)
+             5. **Generate**: Enter the main prompt and click Generate
+
+             ### ✨ Features:
+             - 🎯 **Precise control**: Control the position and attributes of each instance through bounding boxes
+             - 🚀 **ZeroGPU acceleration**: Uses Hugging Face ZeroGPU for fast inference
+             - 🎨 **High-quality generation**: Image generation based on the FLUX model
+             - 🔧 **Flexible parameters**: A wide range of adjustable parameters
+             """)
+
+         with gr.Row():
+             # Left: bounding-box drawing and controls
+             with gr.Column(scale=1):
+                 # Initialization
+                 with gr.Group():
+                     gr.Markdown("### 🚀 Model initialization")
+                     init_btn = gr.Button("🚀 Initialize DreamRenderer", variant="primary", elem_classes=["init-btn"])
+                     init_status = gr.Textbox(label="Initialization status", interactive=False, lines=2)
+
+                 # Bounding-box drawing area
+                 with gr.Group():
+                     gr.Markdown("### 📦 Bounding-box drawing")
+                     gr.HTML("""
+                     <div style="background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%); padding: 10px; border-radius: 8px; margin: 10px 0;">
+                         <p style="margin: 0; color: #1976d2;"><strong>Step 1:</strong> Drag the mouse on the canvas to draw bounding boxes</p>
+                         <p style="margin: 5px 0 0 0; color: #1976d2;"><strong>Step 2:</strong> Enter a detailed description for each box on the right</p>
+                     </div>
+                     """)
+                     bbox_component = gr.HTML(load_bbox_component(), elem_classes=["bbox-container"])
+
+                 # Hidden textbox that receives the bounding-box data
+                 bbox_data_input = gr.Textbox(visible=False, elem_id="bbox_data")
+                 bbox_info = gr.Textbox(label="📦 Bounding-box info", interactive=False, lines=8, placeholder="Bounding-box information will appear here...")
+
+             # Right: bounding-box editor and generation parameters
+             with gr.Column(scale=1):
+                 # Bounding-box editor
+                 with gr.Group():
+                     gr.Markdown("### ✏️ Edit bounding-box descriptions")
+                     bbox_editor = gr.HTML(
+                         value="<div style='text-align: center; padding: 40px; color: #666;'>The editor will appear here once you draw a bounding box</div>",
+                         elem_id="bbox_editor"
+                     )
+
+                 # Prompt settings
+                 with gr.Group():
+                     gr.Markdown("### 📝 Prompt settings")
+                     prompt = gr.Textbox(
+                         label="Main prompt",
+                         placeholder="Describe the overall scene you want to generate...",
+                         lines=3,
+                         value="a beautiful landscape"
+                     )
+                     negative_prompt = gr.Textbox(
+                         label="Negative prompt",
+                         placeholder="Describe what you do not want to see...",
+                         lines=2,
+                         value="blurry, low quality, distorted"
+                     )
+
+                 # Generation parameters
+                 with gr.Group():
+                     gr.Markdown("### ⚙️ Generation parameters")
+
+                     with gr.Row():
+                         num_steps = gr.Slider(
+                             minimum=1, maximum=100, value=20, step=1,
+                             label="Inference steps",
+                             info="More steps usually give better quality"
+                         )
+                         guidance_scale = gr.Slider(
+                             minimum=1.0, maximum=30.0, value=7.5, step=0.5,
+                             label="Guidance scale",
+                             info="Controls how closely the prompt is followed"
+                         )
+
+                     with gr.Row():
+                         width = gr.Slider(
+                             minimum=256, maximum=1024, value=512, step=64,
+                             label="Width"
+                         )
+                         height = gr.Slider(
+                             minimum=256, maximum=1024, value=512, step=64,
+                             label="Height"
+                         )
+
+                     with gr.Row():
+                         use_seed = gr.Checkbox(label="Use fixed seed", value=False)
+                         seed = gr.Number(label="Random seed", value=42, precision=0)
+
+                 # Generate button
+                 generate_btn = gr.Button(
+                     "🎨 Generate image",
+                     variant="primary",
+                     elem_classes=["generate-btn"],
+                     size="lg"
+                 )
+
+                 # Results
+                 with gr.Group():
+                     gr.Markdown("### 🖼️ Result")
+                     output_image = gr.Image(label="Generated image", height=500, show_label=False)
+                     generation_info = gr.Textbox(label="Generation info", interactive=False, lines=6)
+
+         # Examples and further options
+         with gr.Accordion("🎯 Examples and tips", open=False):
+             gr.Markdown("""
+             ### 📌 Prompt examples:
+             - **Landscape**: "a serene mountain landscape with a lake, golden hour lighting"
+             - **City**: "modern city skyline at sunset, futuristic architecture"
+             - **People**: "a group of people in a park, casual clothing, natural lighting"
+
+             ### 🎨 Tips:
+             1. **Box size**: A suitably sized bounding box helps control the result
+             2. **Precise descriptions**: Give each region a specific, precise description
+             3. **Parameter tuning**: A higher guidance scale makes the output follow the prompt more closely
+             4. **Seed control**: Use a fixed seed to get repeatable results
+             """)
+
+         # Event bindings
+         init_btn.click(
+             fn=initialize_pipeline,
+             outputs=init_status,
+             show_progress=True
+         )
+
+         bbox_data_input.change(
+             fn=update_bbox_data,
+             inputs=bbox_data_input,
+             outputs=[bbox_info, bbox_editor]
+         )
+
+         generate_btn.click(
+             fn=generate_image_with_bbox,
+             inputs=[prompt, negative_prompt, num_steps, guidance_scale, width, height, seed, use_seed],
+             outputs=[output_image, generation_info],
+             show_progress=True
+         )
+
+         # JavaScript that handles bounding-box data exchange
+         demo.load(None, None, None, js="""
+         function() {
+             console.log('DreamRenderer interface loaded');
+
+             // Shared state
+             let isDrawing = false;
+             let startX, startY, currentRect;
+             let bboxes = [];
+
+             const canvas = document.getElementById('bboxCanvas');
+             if (!canvas) {
+                 console.error('Canvas element not found');
+                 return;
+             }
+
+             const ctx = canvas.getContext('2d');
+
+             // Drop any previous listeners by swapping in a clone, then attach new ones
+             canvas.replaceWith(canvas.cloneNode(true));
+             const newCanvas = document.getElementById('bboxCanvas');
+             const newCtx = newCanvas.getContext('2d');
+
+             // Reapply the styles
+             newCanvas.style.display = 'block';
+             newCanvas.style.border = '2px solid #4ECDC4';
+             newCanvas.style.backgroundColor = 'white';
+             newCanvas.style.cursor = 'crosshair';
+             newCanvas.style.borderRadius = '8px';
+
+             // Bounding-box editing functions
+             window.updateBboxField = function(index, field, value) {
+                 if (index >= 0 && index < bboxes.length) {
+                     bboxes[index][field] = value;
+                     console.log(`Updated ${field} of bounding box ${index}:`, value);
+                     updateBboxData();
+                 }
+             };
+
+             window.deleteBbox = function(index) {
+                 if (index >= 0 && index < bboxes.length) {
+                     bboxes.splice(index, 1);
+                     console.log(`Deleted bounding box ${index}`);
+                     redrawCanvas();
+                     updateBboxData();
+                 }
+             };
+
+             window.clearAllBboxes = function() {
+                 bboxes = [];
+                 console.log('Cleared all bounding boxes');
+                 redrawCanvas();
+                 updateBboxData();
+             };
+
+             // Redraw the canvas
+             function redrawCanvas() {
+                 newCtx.clearRect(0, 0, newCanvas.width, newCanvas.height);
+                 bboxes.forEach((bbox, index) => {
+                     newCtx.strokeStyle = `hsl(${index * 60}, 70%, 50%)`;
+                     newCtx.lineWidth = 2;
+                     newCtx.strokeRect(bbox.x, bbox.y, bbox.width, bbox.height);
+
+                     // Draw the label
+                     newCtx.fillStyle = `hsl(${index * 60}, 70%, 50%)`;
+                     newCtx.font = '12px Arial';
+                     newCtx.fillText(bbox.label || `Region ${index + 1}`, bbox.x + 5, bbox.y - 5);
+                 });
+             }
+
+             // Push the bounding-box data to Gradio
+             function updateBboxData() {
+                 const relativeBboxes = bboxes.map(b => ({
+                     x: b.x / newCanvas.width,
+                     y: b.y / newCanvas.height,
+                     width: b.width / newCanvas.width,
+                     height: b.height / newCanvas.height,
+                     label: b.label || '',
+                     prompt: b.prompt || ''
+                 }));
+
+                 const dataString = JSON.stringify(relativeBboxes);
+                 console.log('📤 Updating data:', relativeBboxes.length, 'bounding boxes');
+
+                 const textarea = document.querySelector('#bbox_data textarea');
+                 if (textarea) {
+                     textarea.value = dataString;
+                     textarea.dispatchEvent(new Event('input', { bubbles: true }));
+                 }
+             }
+
+             // Attach the drawing event listeners
+             newCanvas.addEventListener('mousedown', function(e) {
+                 isDrawing = true;
+                 startX = e.offsetX;
+                 startY = e.offsetY;
+                 console.log('🎯 Drawing started:', startX, startY);
+             });
+
+             newCanvas.addEventListener('mousemove', function(e) {
+                 if (!isDrawing) return;
+
+                 const currentX = e.offsetX;
+                 const currentY = e.offsetY;
+
+                 // Clear the canvas and redraw all bounding boxes
+                 redrawCanvas();
+
+                 // Draw the box currently being dragged
+                 newCtx.strokeStyle = '#007bff';
+                 newCtx.lineWidth = 2;
+                 newCtx.setLineDash([5, 5]);
+                 const width = currentX - startX;
+                 const height = currentY - startY;
+                 newCtx.strokeRect(startX, startY, width, height);
+                 newCtx.setLineDash([]);
+             });
+
+             newCanvas.addEventListener('mouseup', function(e) {
+                 if (!isDrawing) return;
+                 isDrawing = false;
+
+                 const endX = e.offsetX;
+                 const endY = e.offsetY;
+                 const width = endX - startX;
+                 const height = endY - startY;
+
+                 // Only add the box if it is large enough
+                 if (Math.abs(width) > 10 && Math.abs(height) > 10) {
+                     const bbox = {
+                         x: Math.min(startX, endX),
+                         y: Math.min(startY, endY),
+                         width: Math.abs(width),
+                         height: Math.abs(height),
+                         label: `Region ${bboxes.length + 1}`,
+                         prompt: ''
+                     };
+
+                     bboxes.push(bbox);
+                     console.log('✅ Added bounding box:', bbox);
+
+                     redrawCanvas();
+                     updateBboxData();
+                 }
+
+                 console.log('🎯 Drawing finished, current bounding-box count:', bboxes.length);
+             });
+
+             console.log('🚀 DreamRenderer bounding-box support is ready!');
+         }
+         """)
+
+     return demo
+
+ if __name__ == "__main__":
+     # Create and launch the app
+     demo = create_interface()
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=True,
+         show_api=False,
+         favicon_path=None,
+         show_error=True
+     )
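
One point worth noting about `update_bbox_data` above: it interpolates each box's `label` and `prompt` straight into HTML attribute values and element bodies. The following is a minimal sketch, assuming those fields are free text typed by the user, of how the values could be escaped with the standard-library `html.escape` before interpolation. `render_label_input` is a hypothetical helper used only for illustration and is not part of this commit.

```python
import html


def render_label_input(index: int, label: str) -> str:
    """Hypothetical helper: build the label <input> with an escaped value."""
    # quote=True also escapes double and single quotes, which matters
    # because the value is placed inside a double-quoted attribute.
    safe_label = html.escape(label, quote=True)
    return (
        f'<input type="text" id="bbox_label_{index}" '
        f'value="{safe_label}">'
    )


snippet = render_label_input(0, 'a "red" <car>')
print(snippet)
```

The same call could wrap `prompt` before it is placed between the `<textarea>` tags, so that a description containing `</textarea>` cannot close the element early.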
bbox_component.html ADDED
@@ -0,0 +1,353 @@
+ <!DOCTYPE html>
+ <html>
+ <head>
+     <style>
+         .canvas-container {
+             position: relative;
+             display: inline-block;
+             border: 2px solid #ddd;
+             border-radius: 8px;
+             overflow: hidden;
+             background-color: #f8f9fa;
+         }
+         .bbox-canvas {
+             cursor: crosshair;
+             display: block;
+             background-color: white;
+         }
+         .bbox-list {
+             margin-top: 10px;
+             padding: 10px;
+             background-color: #f8f9fa;
+             border-radius: 5px;
+             max-height: 200px;
+             overflow-y: auto;
+         }
+         .bbox-item {
+             display: flex;
+             justify-content: space-between;
+             align-items: center;
+             padding: 5px;
+             margin: 2px 0;
+             background-color: white;
+             border-radius: 3px;
+             border-left: 4px solid #007bff;
+         }
+         .bbox-input {
+             width: 150px;
+             padding: 2px 5px;
+             border: 1px solid #ddd;
+             border-radius: 3px;
+         }
+         .delete-btn {
+             background-color: #dc3545;
+             color: white;
+             border: none;
+             padding: 2px 8px;
+             border-radius: 3px;
+             cursor: pointer;
+             font-size: 12px;
+         }
+         .delete-btn:hover {
+             background-color: #c82333;
+         }
+         .clear-btn {
+             background-color: #6c757d;
+             color: white;
+             border: none;
+             padding: 5px 15px;
+             border-radius: 3px;
+             cursor: pointer;
+             margin-top: 10px;
+         }
+         .clear-btn:hover {
+             background-color: #5a6268;
+         }
+         .color-indicator {
+             width: 20px;
+             height: 20px;
+             border-radius: 3px;
+             border: 2px solid white;
+             box-shadow: 0 0 3px rgba(0,0,0,0.3);
+         }
+         .info-text {
+             margin: 10px 0;
+             padding: 8px;
+             background-color: #e3f2fd;
+             border-radius: 4px;
+             font-size: 14px;
+             color: #1976d2;
+         }
+         .debug-info {
+             margin: 10px 0;
+             padding: 8px;
+             background-color: #fff3cd;
+             border-radius: 4px;
+             font-size: 12px;
+             color: #856404;
+             font-family: monospace;
+         }
+     </style>
+ </head>
+ <body>
+     <div class="info-text">
+         💡 Drag the mouse on the canvas to draw bounding boxes, then add a description to each box
+     </div>
+
+     <div class="canvas-container">
+         <canvas id="bboxCanvas" class="bbox-canvas" width="512" height="512"></canvas>
+     </div>
+
+     <script>
+         console.log('Bounding-box component loaded');
+
+         const canvas = document.getElementById('bboxCanvas');
+         const ctx = canvas.getContext('2d');
+         const debugInfo = document.getElementById('debugInfo');
+
+         let isDrawing = false;
+         let startX, startY;
+         let boxes = [];
+         let currentBox = null;
+
+         const colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD', '#98D8C8', '#F7DC6F'];
+         let colorIndex = 0;
+
+         // Debug logging helper
+         function log(message) {
+             console.log(message);
+             if (debugInfo) debugInfo.textContent = `Debug: ${message}`; // guard: this markup has no #debugInfo element
+         }
+
+         // Initialize the canvas
+         function initCanvas() {
+             ctx.fillStyle = '#ffffff';
+             ctx.fillRect(0, 0, canvas.width, canvas.height);
+             ctx.strokeStyle = '#ddd';
+             ctx.lineWidth = 1;
+             ctx.strokeRect(0, 0, canvas.width, canvas.height);
+             log('Canvas initialized');
+         }
+
+         // Event listeners
+         canvas.addEventListener('mousedown', startDrawing);
+         canvas.addEventListener('mousemove', draw);
+         canvas.addEventListener('mouseup', stopDrawing);
+         canvas.addEventListener('mouseleave', stopDrawing); // also stop when the mouse leaves the canvas
+
+         function startDrawing(e) {
+             isDrawing = true;
+             const rect = canvas.getBoundingClientRect();
+             startX = e.clientX - rect.left;
+             startY = e.clientY - rect.top;
+             log(`Drawing started: (${Math.round(startX)}, ${Math.round(startY)})`);
+         }
+
+         function draw(e) {
+             if (!isDrawing) return;
+
+             const rect = canvas.getBoundingClientRect();
+             const currentX = e.clientX - rect.left;
+             const currentY = e.clientY - rect.top;
+
+             redrawCanvas();
+
+             // Draw the box currently being dragged
+             ctx.strokeStyle = colors[colorIndex % colors.length];
+             ctx.lineWidth = 2;
+             ctx.setLineDash([5, 5]);
+             ctx.strokeRect(startX, startY, currentX - startX, currentY - startY);
+             ctx.setLineDash([]);
+         }
+
+         function stopDrawing(e) {
+             if (!isDrawing) return;
+             isDrawing = false;
+
+             const rect = canvas.getBoundingClientRect();
+             const endX = e.clientX - rect.left;
+             const endY = e.clientY - rect.top;
+
+             const width = Math.abs(endX - startX);
+             const height = Math.abs(endY - startY);
+
+             if (width > 10 && height > 10) {
+                 const box = {
+                     x: Math.min(startX, endX),
+                     y: Math.min(startY, endY),
+                     width: width,
+                     height: height,
+                     color: colors[colorIndex % colors.length],
+                     label: '',
+                     id: Date.now()
+                 };
+
+                 boxes.push(box);
+                 colorIndex++;
+                 addBoxToList(box);
+                 redrawCanvas();
+                 updateOutput();
+                 log(`Bounding box added: ${boxes.length} total`);
+             } else {
+                 redrawCanvas();
+                 log('Bounding box too small, ignored');
+             }
+         }
+
+         function redrawCanvas() {
+             // Clear the canvas and repaint the background
+             ctx.fillStyle = '#ffffff';
+             ctx.fillRect(0, 0, canvas.width, canvas.height);
+
+             // Draw all bounding boxes
+             boxes.forEach((box, index) => {
+                 ctx.strokeStyle = box.color;
+                 ctx.lineWidth = 2;
+                 ctx.setLineDash([]);
+                 ctx.strokeRect(box.x, box.y, box.width, box.height);
+
+                 // Draw the label
+                 if (box.label) {
+                     ctx.fillStyle = box.color;
+                     ctx.font = '14px Arial';
+                     ctx.fillText(box.label, box.x, box.y - 5);
+                 }
+
+                 // Draw the index number
+                 ctx.fillStyle = box.color;
+                 ctx.font = 'bold 12px Arial';
+                 ctx.fillText(`${index + 1}`, box.x + 3, box.y + 15);
+             });
+         }
+
+         function addBoxToList(box) {
+             const bboxItems = document.getElementById('bboxItems');
+             const item = document.createElement('div');
+             item.className = 'bbox-item';
+             item.id = `bbox-item-${box.id}`;
+             item.innerHTML = `
+                 <div style="display: flex; align-items: center; gap: 10px;">
+                     <div class="color-indicator" style="background-color: ${box.color}"></div>
+                     <input type="text" class="bbox-input" placeholder="Enter a description..."
+                            onchange="updateBoxLabel(${box.id}, this.value)"
+                            oninput="updateBoxLabel(${box.id}, this.value)">
+                     <span style="font-size: 12px; color: #666;">
+                         (${Math.round(box.x)}, ${Math.round(box.y)}, ${Math.round(box.width)}, ${Math.round(box.height)})
+                     </span>
+                 </div>
+                 <button class="delete-btn" onclick="deleteBox(${box.id})">Delete</button>
+             `;
+             if (bboxItems) bboxItems.appendChild(item); // guard: this markup has no #bboxItems element
+         }
+
+         function updateBoxLabel(boxId, label) {
+             const box = boxes.find(b => b.id === boxId);
+             if (box) {
+                 box.label = label;
+                 redrawCanvas();
+                 updateOutput();
+                 log(`Label updated: ${label}`);
+             }
+         }
+
+         function deleteBox(boxId) {
+             const oldLength = boxes.length;
+             boxes = boxes.filter(b => b.id !== boxId);
+             redrawBboxList();
+             redrawCanvas();
+             updateOutput();
+             log(`Bounding box deleted: ${oldLength} -> ${boxes.length}`);
+         }
+
+         function clearAllBoxes() {
+             boxes = [];
+             redrawBboxList();
+             redrawCanvas();
+             updateOutput();
+             log('All bounding boxes cleared');
+         }
+
+         function redrawBboxList() {
+             const bboxItems = document.getElementById('bboxItems');
+             if (bboxItems) bboxItems.innerHTML = ''; // guard: this markup has no #bboxItems element
+             boxes.forEach(box => addBoxToList(box));
+         }
+
+         function updateOutput() {
+             try {
+                 // Pass the bounding-box data to Gradio
+                 const boxData = boxes.map(box => ({
+                     x: box.x / canvas.width, // normalized coordinates
+                     y: box.y / canvas.height,
+                     width: box.width / canvas.width,
+                     height: box.height / canvas.height,
+                     label: box.label || ''
+                 }));
+
+                 const dataString = JSON.stringify(boxData);
+                 log(`Sending data: ${boxData.length} bounding boxes`);
+
+                 // Look up the Gradio input directly (the component is embedded in the page itself)
+                 const bboxInput = document.querySelector('#bbox_data textarea');
+                 if (bboxInput) {
+                     bboxInput.value = dataString;
+
+                     // Fire several events so that Gradio notices the change
+                     bboxInput.dispatchEvent(new Event('input', { bubbles: true }));
+                     bboxInput.dispatchEvent(new Event('change', { bubbles: true }));
+                     bboxInput.dispatchEvent(new Event('blur', { bubbles: true }));
+
+                     log('Gradio input updated directly');
+                 } else {
+                     // If the direct lookup fails, retry after a delay
+                     setTimeout(() => {
+                         const delayedBboxInput = document.querySelector('#bbox_data textarea') ||
+                             document.querySelector('[data-testid="textbox"] textarea');
+                         if (delayedBboxInput) {
+                             delayedBboxInput.value = dataString;
+                             delayedBboxInput.dispatchEvent(new Event('input', { bubbles: true }));
+                             delayedBboxInput.dispatchEvent(new Event('change', { bubbles: true }));
+                             log('Gradio input updated after delay');
+                         } else {
+                             log('Gradio input not found');
+                         }
+                     }, 500);
+                 }
+
+                 // Also dispatch a custom event as a fallback
+                 document.dispatchEvent(new CustomEvent('bbox_data_update', {
+                     detail: { data: boxData, dataString: dataString }
+                 }));
+
+             } catch (error) {
+                 log(`Error while updating output: ${error.message}`);
+                 console.error('updateOutput error:', error);
+             }
+         }
+
+         // Receive image updates from Gradio
+         window.addEventListener('message', function(event) {
+             if (event.data && event.data.type === 'update_image') {
+                 const img = new Image();
+                 img.onload = function() {
+                     canvas.width = img.width;
+                     canvas.height = img.height;
+                     ctx.drawImage(img, 0, 0);
+                     redrawCanvas();
+                     log('Image updated');
+                 };
+                 img.src = event.data.imageUrl;
+             }
+         });
+
+         // Initialize once the page has finished loading
+         window.addEventListener('load', function() {
+             initCanvas();
+             log('Component ready');
+         });
+
+         // Initialize immediately
+         initCanvas();
+     </script>
+ </body>
+ </html>
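
The `updateOutput` function above serializes every box as a JSON object with normalized `x`, `y`, `width`, and `height` (each divided by the canvas size) plus a `label`. As a minimal sketch under that assumption, the Python side could turn such a payload back into pixel rectangles as shown below. `to_pixel_boxes` is a hypothetical helper, and the clamping to the image bounds is an added assumption; the commit as shown does not confirm that the pipeline clamps.

```python
import json
from typing import Dict, List, Tuple


def to_pixel_boxes(bbox_json: str, width: int, height: int) -> List[Tuple[int, int, int, int]]:
    """Hypothetical helper: normalized boxes -> (left, top, right, bottom) in pixels."""
    boxes: List[Dict] = json.loads(bbox_json) if bbox_json.strip() else []
    result = []
    for box in boxes:
        left = int(box["x"] * width)
        top = int(box["y"] * height)
        right = int((box["x"] + box["width"]) * width)
        bottom = int((box["y"] + box["height"]) * height)
        # Clamp to the image bounds (an assumption, not taken from the commit)
        left, right = max(0, left), min(width, right)
        top, bottom = max(0, top), min(height, bottom)
        result.append((left, top, right, bottom))
    return result


payload = '[{"x": 0.25, "y": 0.5, "width": 0.5, "height": 0.25, "label": "cat"}]'
pixel_boxes = to_pixel_boxes(payload, 512, 512)
print(pixel_boxes)
```

Keeping the payload normalized means the same boxes remain valid when the output size differs from the 512×512 drawing canvas.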
dream_renderer.py ADDED
@@ -0,0 +1,312 @@
+ """
+ DreamRenderer implementation module
+ """
+
+ import torch
+ import torch.nn.functional as F
+ from diffusers import FluxPipeline
+ from PIL import Image, ImageDraw
+ import numpy as np
+ from typing import List, Dict, Optional, Tuple
+ import spaces
+
+ class DreamRendererPipeline:
+     """
+     DreamRenderer pipeline implementation
+     """
+
+     def __init__(self, model_id: str = "black-forest-labs/FLUX.1-dev"):
+         """
+         Initialize the DreamRenderer pipeline.
+
+         Args:
+             model_id: ID of the model to load
+         """
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model_id = model_id
+         self.pipe = None      # created lazily by load_model()
+         self.loaded = False
+
+     def load_model(self):
+         """Load the FLUX model."""
+         try:
+             print(f"Loading model: {self.model_id}")
+             self.pipe = FluxPipeline.from_pretrained(
+                 self.model_id,
+                 torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
+                 use_safetensors=True
+             )
+             self.pipe = self.pipe.to(self.device)
+
+             # Enable memory-efficient attention when xformers is usable; the call
+             # raises if xformers is not installed, so treat it as optional
+             if hasattr(self.pipe, 'enable_xformers_memory_efficient_attention'):
+                 try:
+                     self.pipe.enable_xformers_memory_efficient_attention()
+                 except Exception as e:
+                     print(f"xformers attention not enabled: {e}")
+
+             self.loaded = True
+             print("Model loaded!")
+             return True
+
+         except Exception as e:
+             print(f"Failed to load model: {str(e)}")
+             self.loaded = False
+             return False
+
+     def create_layout_mask(self, bbox_data: List[Dict], width: int, height: int) -> torch.Tensor:
+         """
+         Create a layout mask from the bounding-box data.
+
+         Args:
+             bbox_data: list of bounding boxes (normalized x, y, width, height)
+             width: image width in pixels
+             height: image height in pixels
+
+         Returns:
+             Layout mask tensor of shape (height, width); 0 is background and
+             box i is marked with the value i + 1
+         """
+         mask = torch.zeros((height, width), dtype=torch.float32)
+
+         for i, bbox in enumerate(bbox_data):
+             x = int(bbox['x'] * width)
+             y = int(bbox['y'] * height)
+             w = int(bbox['width'] * width)
+             h = int(bbox['height'] * height)
+
+             # Mark the region in the mask; later boxes overwrite earlier ones
+             mask[y:y+h, x:x+w] = i + 1
+
+         return mask
+
+     def create_attention_mask(self, bbox_data: List[Dict], width: int, height: int) -> List[torch.Tensor]:
+         """
+         Create one attention mask per instance.
+
+         Args:
+             bbox_data: list of bounding boxes (normalized x, y, width, height)
+             width: image width in pixels
+             height: image height in pixels
+
+         Returns:
+             List of attention masks, one (height, width) tensor per box
+         """
+         masks = []
+
+         for bbox in bbox_data:
+             mask = torch.zeros((height, width), dtype=torch.float32)
+
+             x = int(bbox['x'] * width)
+             y = int(bbox['y'] * height)
+             w = int(bbox['width'] * width)
+             h = int(bbox['height'] * height)
+
+             # Start from a hard box mask
+             mask[y:y+h, x:x+w] = 1.0
+
+             # Soften the box edges with a 3x3 box blur (average pooling);
+             # avg_pool2d expects an (N, C, H, W) tensor and also runs on CPU
+             mask = mask.unsqueeze(0).unsqueeze(0)
+             mask = F.avg_pool2d(mask, kernel_size=3, stride=1, padding=1)
+             mask = mask.squeeze(0).squeeze(0)
+
+             masks.append(mask)
+
+         return masks
+
+     def modify_attention_weights(self, attention_weights: torch.Tensor,
+                                  attention_masks: List[torch.Tensor],
+                                  current_token_idx: int) -> torch.Tensor:
+         """
+         Modify the attention weights to achieve region control.
+
+         Args:
+             attention_weights: original attention weights
+             attention_masks: list of attention masks
+             current_token_idx: index of the current token
+
+         Returns:
+             Modified attention weights
+         """
+         # Placeholder for DreamRenderer's core attention-modification logic:
+         # adjust the weights using the current token's region mask
+
+         if current_token_idx < len(attention_masks):
+             mask = attention_masks[current_token_idx]
+
+             # Move the mask to the same device as the attention weights
+             if mask.device != attention_weights.device:
+                 mask = mask.to(attention_weights.device)
+
+             # Boost attention inside the region by up to 50%; the mask must be
+             # broadcastable to the shape of attention_weights
+             attention_weights = attention_weights * (1 + mask * 0.5)
+
+         return attention_weights
+
+     @spaces.GPU
+     def generate_image(self,
+                        prompt: str,
+                        bbox_data: List[Dict],
+                        negative_prompt: str = "",
+                        num_inference_steps: int = 20,
+                        guidance_scale: float = 7.5,
+                        width: int = 512,
+                        height: int = 512,
+                        seed: Optional[int] = None) -> Image.Image:
+         """
+         Main entry point for image generation.
+
+         Args:
+             prompt: main prompt
+             bbox_data: bounding-box data
+             negative_prompt: negative prompt
+             num_inference_steps: number of inference steps
+             guidance_scale: guidance strength
+             width: image width
+             height: image height
+             seed: random seed
+
+         Returns:
+             The generated image
+         """
+         if not self.loaded:
+             if not self.load_model():
+                 # If the model fails to load, return a demo image instead
+                 return self._create_demo_image(prompt, bbox_data, width, height)
+
+         # Seed the generator when a seed is given
+         if seed is not None:
+             generator = torch.Generator(device=self.device).manual_seed(seed)
+         else:
+             generator = None
+
+         try:
+             # Build the full prompt (main prompt plus region labels)
+             full_prompt = self._build_full_prompt(prompt, bbox_data)
+
+             # Without bounding boxes, fall back to standard generation
+             if not bbox_data:
+                 image = self.pipe(
+                     prompt=full_prompt,
+                     negative_prompt=negative_prompt,
+                     num_inference_steps=num_inference_steps,
+                     guidance_scale=guidance_scale,
+                     width=width,
+                     height=height,
+                     generator=generator
+                 ).images[0]
+             else:
+                 # Use DreamRenderer's region-control path
+                 image = self._generate_with_bbox_control(
+                     full_prompt, bbox_data, negative_prompt,
+                     num_inference_steps, guidance_scale,
+                     width, height, generator
+                 )
+
+             return image
+
+         except Exception as e:
+             print(f"Error while generating the image: {str(e)}")
+             # Return a demo image
+             return self._create_demo_image(prompt, bbox_data, width, height)
+
+     def _build_full_prompt(self, main_prompt: str, bbox_data: List[Dict]) -> str:
+         """Build the full prompt by appending each region's label to the main prompt."""
+         full_prompt = main_prompt
+
+         if bbox_data:
+             # Collect the non-empty labels; .get() tolerates boxes without a label
+             region_descriptions = [
+                 str(bbox['label']) for bbox in bbox_data if bbox.get('label')
+             ]
+
+             if region_descriptions:
+                 full_prompt += ", " + ", ".join(region_descriptions)
+
+         return full_prompt
+
+     def _generate_with_bbox_control(self, prompt: str, bbox_data: List[Dict],
+                                     negative_prompt: str, num_inference_steps: int,
+                                     guidance_scale: float, width: int, height: int,
+                                     generator: Optional[torch.Generator]) -> Image.Image:
+         """Generate an image with bounding-box control."""
+
+         # Create the attention masks (not yet consumed by the pipeline call below)
+         attention_masks = self.create_attention_mask(bbox_data, width, height)
+
+         # DreamRenderer's core algorithm should be implemented here,
+         # including attention modification and cross-attention control
+
+         # For now, generate with the standard pipeline; this can later be
+         # replaced with the actual DreamRenderer implementation
+         image = self.pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             width=width,
+             height=height,
+             generator=generator
+         ).images[0]
+
+         # Draw the bounding boxes on the generated image as a demonstration
+         image = self._add_bbox_overlay(image, bbox_data)
+
+         return image
+
+     def _add_bbox_overlay(self, image: Image.Image, bbox_data: List[Dict]) -> Image.Image:
+         """Draw a bounding-box overlay on the image (for demonstration); modifies it in place."""
+         if not bbox_data:
+             return image
+
+         draw = ImageDraw.Draw(image)
+         colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange', 'pink', 'cyan']
+
+         for i, bbox in enumerate(bbox_data):
+             color = colors[i % len(colors)]
+
+             x = int(bbox['x'] * image.width)
+             y = int(bbox['y'] * image.height)
+             w = int(bbox['width'] * image.width)
+             h = int(bbox['height'] * image.height)
+
+             # Draw the bounding box
+             draw.rectangle([x, y, x+w, y+h], outline=color, width=2)
+
+             # Draw the label just above the box, kept inside the image
+             if bbox.get('label'):
+                 draw.text((x, max(0, y-15)), bbox['label'], fill=color)
+
+         return image
+
+     def _create_demo_image(self, prompt: str, bbox_data: List[Dict],
+                            width: int, height: int) -> Image.Image:
+         """Create a demo image (used when the model fails to load or generation fails)."""
+         image = Image.new('RGB', (width, height))
+         draw = ImageDraw.Draw(image)
+
+         # Draw a vertical gradient background, one row at a time
+         for y in range(height):
+             color_value = int(255 * (y / height))
+             color = (100 + color_value//3, 150 + color_value//4, 200 + color_value//5)
+             draw.line([(0, y), (width, y)], fill=color)
+
+         # Add the prompt text
+         draw.text((10, 10), f"Prompt: {prompt}", fill='white')
+         draw.text((10, 30), "DreamRenderer Demo", fill='white')
+
+         # Draw the bounding boxes
+         colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange']
+         for i, bbox in enumerate(bbox_data):
+             color = colors[i % len(colors)]
+
+             x = int(bbox['x'] * width)
+             y = int(bbox['y'] * height)
+             w = int(bbox['width'] * width)
+             h = int(bbox['height'] * height)
+
+             # Draw the bounding box
+             draw.rectangle([x, y, x+w, y+h], outline=color, width=3)
+
+             # Draw the label just above the box, kept inside the image
+             if bbox.get('label'):
+                 draw.text((x, max(0, y-20)), bbox['label'], fill=color)
+
+         return image
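The same normalized-to-pixel conversion is repeated in `create_layout_mask`, `create_attention_mask`, `_add_bbox_overlay` and `_create_demo_image`, and none of them clamps a box that extends past the image. One way to centralize it is sketched below, assuming the coordinates are fractions of the image size as in the code above. `bbox_to_pixels` is a hypothetical helper, not part of this commit.

```python
from typing import Dict, Tuple


def bbox_to_pixels(bbox: Dict, width: int, height: int) -> Tuple[int, int, int, int]:
    """Convert a normalized box to clamped pixel corners (x0, y0, x1, y1)."""
    x0 = int(bbox["x"] * width)
    y0 = int(bbox["y"] * height)
    x1 = x0 + int(bbox["width"] * width)
    y1 = y0 + int(bbox["height"] * height)
    # Clamp every corner to the image so slicing and drawing stay in bounds
    x0, x1 = max(0, min(x0, width)), max(0, min(x1, width))
    y0, y1 = max(0, min(y0, height)), max(0, min(y1, height))
    return x0, y0, x1, y1
```

With such a helper, a mask update could read `mask[y0:y1, x0:x1] = 1.0`. A negative start index would then no longer wrap around to the opposite edge of the tensor.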
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ gradio>=4.40.0
+ spaces
+ torch>=2.0.0
+ torchvision
+ diffusers>=0.30.0
+ transformers>=4.30.0
+ accelerate
+ pillow
+ numpy
+ opencv-python-headless
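`dream_renderer.py` imports `FluxPipeline` from `diffusers`, so the version pin above only helps if the installed release actually ships that class. A small startup check can make a mismatch visible before the app starts serving requests. This is a sketch only; `has_symbol` is a hypothetical helper and is not part of this commit.

```python
import importlib


def has_symbol(module_name: str, symbol: str) -> bool:
    """Return True if `module_name` can be imported and exposes `symbol`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False  # the package is not installed at all
    return hasattr(module, symbol)
```

For example, `app.py` could call `has_symbol("diffusers", "FluxPipeline")` once at startup and log a clear upgrade hint when it returns `False`.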