Spaces:
Sleeping
Sleeping
""" | |
DreamRenderer实现模块 | |
""" | |
import torch | |
import torch.nn.functional as F | |
from diffusers import FluxPipeline | |
from PIL import Image, ImageDraw | |
import numpy as np | |
from typing import List, Dict, Optional, Tuple | |
import spaces | |
class DreamRendererPipeline:
    """
    DreamRenderer pipeline built on a diffusers ``FluxPipeline``.

    The model is loaded lazily on the first call to ``generate_image``. If
    loading fails, or generation raises, a synthetic demo image is returned
    instead of propagating the error.

    Bounding boxes are dicts with numeric ``x``, ``y``, ``width`` and
    ``height`` keys, expressed as fractions of the image size (they are
    multiplied by the pixel width/height below), plus an optional ``label``
    string.

    Region control is currently a placeholder: labels are appended to the
    prompt and boxes are drawn over the generated image, but no attention
    modification is applied to the pipeline yet.
    """

    # Outline colours cycled per box by _add_bbox_overlay.
    _OVERLAY_COLORS = ('red', 'blue', 'green', 'yellow', 'purple', 'orange', 'pink', 'cyan')
    # Outline colours cycled per box by _create_demo_image.
    _DEMO_COLORS = ('red', 'blue', 'green', 'yellow', 'purple', 'orange')

    def __init__(self, model_id: str = "black-forest-labs/FLUX.1-dev"):
        """
        Set up the pipeline wrapper without loading any weights.

        Args:
            model_id: Hugging Face model id passed to ``FluxPipeline.from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = model_id
        self.pipe = None
        self.loaded = False

    def load_model(self) -> bool:
        """
        Load the FLUX model into ``self.pipe`` and move it to ``self.device``.

        Returns:
            True on success. False if loading failed; the error is printed and
            ``self.pipe`` is reset to None.
        """
        try:
            print(f"正在加载模型: {self.model_id}")
            pipe = FluxPipeline.from_pretrained(
                self.model_id,
                # bfloat16 on CUDA, float32 otherwise.
                torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
                use_safetensors=True,
            )
            self.pipe = pipe.to(self.device)
        except Exception as e:
            print(f"模型加载失败: {str(e)}")
            self.pipe = None
            self.loaded = False
            return False

        # Memory-efficient attention is an optional optimisation. diffusers raises
        # when xformers is not installed or CUDA is unavailable; that must not turn
        # a successful model load into a failure, so it is attempted separately.
        if self.device == "cuda" and hasattr(self.pipe, 'enable_xformers_memory_efficient_attention'):
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception as e:
                print(f"xformers attention not enabled, using default attention: {e}")

        self.loaded = True
        print("模型加载完成!")
        return True

    @staticmethod
    def _bbox_to_pixels(bbox: Dict, width: int, height: int) -> Tuple[int, int, int, int]:
        """
        Convert a normalized box dict into clamped pixel corners.

        ``x``/``width`` are scaled by the image width and ``y``/``height`` by the
        image height, then truncated with ``int()``. The result is clamped so that
        ``0 <= x0 <= x1 <= width`` and ``0 <= y0 <= y1 <= height``. Without the
        clamp, a negative coordinate would act as a negative index when slicing
        a tensor and silently select the wrong pixels (or none).

        Returns:
            ``(x0, y0, x1, y1)``; ``x1``/``y1`` are the far edges (exclusive when
            used as slice bounds).
        """
        x0 = int(bbox['x'] * width)
        y0 = int(bbox['y'] * height)
        x1 = x0 + int(bbox['width'] * width)
        y1 = y0 + int(bbox['height'] * height)
        x0 = min(max(x0, 0), width)
        y0 = min(max(y0, 0), height)
        x1 = min(max(x1, x0), width)
        y1 = min(max(y1, y0), height)
        return x0, y0, x1, y1

    def create_layout_mask(self, bbox_data: List[Dict], width: int, height: int) -> torch.Tensor:
        """
        Build a single layout mask labelling each box region.

        Args:
            bbox_data: List of box dicts.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            A float32 tensor of shape ``(height, width)``. Background is 0;
            pixels inside the i-th box (0-based) hold ``i + 1``. Where boxes
            overlap, the later box wins.
        """
        mask = torch.zeros((height, width), dtype=torch.float32)
        for region_id, bbox in enumerate(bbox_data, start=1):
            x0, y0, x1, y1 = self._bbox_to_pixels(bbox, width, height)
            mask[y0:y1, x0:x1] = region_id
        return mask

    def create_attention_mask(self, bbox_data: List[Dict], width: int, height: int) -> List[torch.Tensor]:
        """
        Build one soft-edged mask per box.

        Each mask is 1.0 inside its box and 0.0 outside, then smoothed with a
        3x3 mean filter (``avg_pool2d``, stride 1, padding 1) so that box edges
        ramp instead of stepping.

        Args:
            bbox_data: List of box dicts.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            A list of float32 CPU tensors, each of shape ``(height, width)``.
        """
        masks = []
        for bbox in bbox_data:
            x0, y0, x1, y1 = self._bbox_to_pixels(bbox, width, height)
            # avg_pool2d expects (N, C, H, W).
            mask = torch.zeros((1, 1, height, width), dtype=torch.float32)
            mask[..., y0:y1, x0:x1] = 1.0
            # The smoothing used to run only when CUDA was available, so CPU runs
            # produced hard edges. It is cheap on CPU, so it is applied always.
            mask = F.avg_pool2d(mask, kernel_size=3, stride=1, padding=1)
            # Index rather than squeeze() so that a genuine size-1 height or
            # width dimension is preserved.
            masks.append(mask[0, 0])
        return masks

    def modify_attention_weights(self, attention_weights: torch.Tensor,
                                 attention_masks: List[torch.Tensor],
                                 current_token_idx: int,
                                 boost: float = 0.5) -> torch.Tensor:
        """
        Boost attention weights inside the region mask for the current token.

        The weights are multiplied by ``1 + boost * mask``, i.e. up to
        ``1 + boost`` where the mask is 1.

        Args:
            attention_weights: Original attention weights. NOTE(review): the
                mask must broadcast against this tensor; the shapes real
                callers use are not visible here -- confirm.
            attention_masks: Per-token region masks.
            current_token_idx: Index of the mask to apply. Out-of-range values,
                including negative ones, leave the weights unchanged.
            boost: Extra weight applied where the mask is 1 (default 0.5).

        Returns:
            The modified weights, on the same device and with the same dtype as
            ``attention_weights``.
        """
        if not 0 <= current_token_idx < len(attention_masks):
            return attention_weights
        # Match dtype as well as device so that bf16 weights are not silently
        # promoted to float32 by multiplying with a float32 mask.
        mask = attention_masks[current_token_idx].to(
            device=attention_weights.device, dtype=attention_weights.dtype
        )
        return attention_weights * (1 + mask * boost)

    def generate_image(self,
                       prompt: str,
                       bbox_data: List[Dict],
                       negative_prompt: str = "",
                       num_inference_steps: int = 20,
                       guidance_scale: float = 7.5,
                       width: int = 512,
                       height: int = 512,
                       seed: Optional[int] = None) -> "Image.Image":
        """
        Generate an image, optionally guided by bounding boxes.

        The model is loaded on first use. If loading fails, or the pipeline
        raises during generation, the error is printed and a demo image is
        returned instead.

        Args:
            prompt: Main text prompt.
            bbox_data: Box dicts; their labels are appended to the prompt.
            negative_prompt: Negative prompt; only forwarded when non-empty.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Guidance strength passed to the pipeline.
            width: Output width in pixels.
            height: Output height in pixels.
            seed: Optional seed for a reproducible ``torch.Generator``.

        Returns:
            The generated PIL image, or a demo image on failure.
        """
        if not self.loaded and not self.load_model():
            # Model unavailable: return a placeholder so the caller still gets an image.
            return self._create_demo_image(prompt, bbox_data, width, height)

        # A seeded generator makes runs reproducible; None leaves seeding to the pipeline.
        generator = (torch.Generator(device=self.device).manual_seed(seed)
                     if seed is not None else None)

        try:
            full_prompt = self._build_full_prompt(prompt, bbox_data)
            if not bbox_data:
                return self._run_pipe(full_prompt, negative_prompt, num_inference_steps,
                                      guidance_scale, width, height, generator)
            return self._generate_with_bbox_control(
                full_prompt, bbox_data, negative_prompt,
                num_inference_steps, guidance_scale,
                width, height, generator,
            )
        except Exception as e:
            print(f"生成图像时出错: {str(e)}")
            return self._create_demo_image(prompt, bbox_data, width, height)

    def _run_pipe(self, prompt: str, negative_prompt: str, num_inference_steps: int,
                  guidance_scale: float, width: int, height: int,
                  generator: Optional[torch.Generator]) -> "Image.Image":
        """Run the loaded pipeline once and return its first image."""
        call_kwargs = dict(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            generator=generator,
        )
        # Only forward a non-empty negative prompt: FluxPipeline versions without
        # negative-prompt support reject the keyword, and an empty string carries
        # no information. NOTE(review): in versions that accept it, FLUX appears
        # to apply it only when true_cfg_scale > 1, which is not set here -- confirm.
        if negative_prompt:
            call_kwargs["negative_prompt"] = negative_prompt
        return self.pipe(**call_kwargs).images[0]

    def _build_full_prompt(self, main_prompt: str, bbox_data: List[Dict]) -> str:
        """
        Append the non-empty box labels to the main prompt, comma-separated.

        Boxes without a ``label`` key, or with an empty label, are skipped.
        """
        labels = [str(bbox['label']) for bbox in (bbox_data or []) if bbox.get('label')]
        if not labels:
            return main_prompt
        return main_prompt + ", " + ", ".join(labels)

    def _generate_with_bbox_control(self, prompt: str, bbox_data: List[Dict],
                                    negative_prompt: str, num_inference_steps: int,
                                    guidance_scale: float, width: int, height: int,
                                    generator: Optional[torch.Generator]) -> "Image.Image":
        """
        Generate with bounding-box control (placeholder implementation).

        TODO: the DreamRenderer attention control is not implemented yet. For
        now this runs the standard pipeline and draws the boxes over the result.
        create_attention_mask / modify_attention_weights are presumably the
        building blocks for the real implementation once the pipeline's
        attention is hooked.
        """
        # Attention masks used to be computed here and then discarded; build
        # them with create_attention_mask() once something actually consumes them.
        image = self._run_pipe(prompt, negative_prompt, num_inference_steps,
                               guidance_scale, width, height, generator)
        # Draw the boxes on the result so the requested layout is visible.
        return self._add_bbox_overlay(image, bbox_data)

    def _draw_bboxes(self, draw: "ImageDraw.ImageDraw", bbox_data: List[Dict],
                     width: int, height: int, colors: Tuple[str, ...],
                     line_width: int, label_offset: int) -> None:
        """
        Outline each box and write its label ``label_offset`` pixels above it.

        Colours are cycled from ``colors`` in box order.
        """
        for i, bbox in enumerate(bbox_data):
            color = colors[i % len(colors)]
            x0, y0, x1, y1 = self._bbox_to_pixels(bbox, width, height)
            draw.rectangle([x0, y0, x1, y1], outline=color, width=line_width)
            label = bbox.get('label')
            if label:
                # Clamp so labels of boxes touching the top edge stay on-canvas
                # instead of being drawn at a negative y.
                draw.text((x0, max(y0 - label_offset, 0)), str(label), fill=color)

    def _add_bbox_overlay(self, image: "Image.Image", bbox_data: List[Dict]) -> "Image.Image":
        """Draw the boxes and labels onto ``image`` in place (demo visualisation) and return it."""
        if not bbox_data:
            return image
        self._draw_bboxes(ImageDraw.Draw(image), bbox_data, image.width, image.height,
                          self._OVERLAY_COLORS, line_width=2, label_offset=15)
        return image

    def _create_demo_image(self, prompt: str, bbox_data: List[Dict],
                           width: int, height: int) -> "Image.Image":
        """
        Draw a placeholder image used when the model is unavailable or generation fails.

        The image has a vertical gradient background, the prompt text, a
        "DreamRenderer Demo" caption, and the boxes with their labels.
        """
        image = Image.new('RGB', (width, height))
        draw = ImageDraw.Draw(image)
        # Vertical gradient: one horizontal line per row, brightening downwards.
        for y in range(height):
            shade = int(255 * (y / height))
            draw.line([(0, y), (width, y)],
                      fill=(100 + shade // 3, 150 + shade // 4, 200 + shade // 5))
        draw.text((10, 10), f"Prompt: {prompt}", fill='white')
        draw.text((10, 30), "DreamRenderer Demo", fill='white')
        self._draw_bboxes(draw, bbox_data or [], width, height,
                          self._DEMO_COLORS, line_width=3, label_offset=20)
        return image