DreamRenderer / dream_renderer.py
Longxiang-ai's picture
Initial commit: DreamRenderer with Zero GPU support
0274afd
raw
history blame
11.2 kB
"""
DreamRenderer implementation module.

Provides ``DreamRendererPipeline``, a FLUX-based text-to-image wrapper that
accepts per-region bounding boxes alongside the text prompt.
"""
import torch
import torch.nn.functional as F
from diffusers import FluxPipeline
from PIL import Image, ImageDraw
import numpy as np
from typing import List, Dict, Optional, Tuple
import spaces
class DreamRendererPipeline:
    """
    DreamRenderer pipeline built on top of a diffusers ``FluxPipeline``.

    The FLUX model is loaded lazily on the first ``generate_image`` call.
    Region control is not yet wired into the denoising loop. Instead, the
    prompt is extended with the region labels and the boxes are drawn onto
    the result as an overlay. If the model cannot be loaded or generation
    fails, a synthetic placeholder image is returned instead.

    Bounding boxes are dicts with keys ``'x'``, ``'y'``, ``'width'`` and
    ``'height'``, expressed as fractions of the image size. They may also
    carry an optional ``'label'`` string.
    """

    # Outline/label colours for successive boxes; cycled when there are more boxes.
    _BOX_COLORS = ('red', 'blue', 'green', 'yellow', 'purple', 'orange', 'pink', 'cyan')

    def __init__(self, model_id: str = "black-forest-labs/FLUX.1-dev"):
        """
        Initialise the wrapper without loading any weights.

        Args:
            model_id: Model ID or path passed to ``FluxPipeline.from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = model_id
        self.pipe: Optional[FluxPipeline] = None
        self.loaded = False

    def load_model(self) -> bool:
        """
        Load the FLUX weights and move them to ``self.device``.

        Weights are loaded in bfloat16 on CUDA and float32 on CPU.

        Returns:
            True on success. False if loading failed; the error is printed.
        """
        try:
            print(f"正在加载模型: {self.model_id}")
            pipe = FluxPipeline.from_pretrained(
                self.model_id,
                torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
                use_safetensors=True,
            )
            pipe = pipe.to(self.device)
        except Exception as e:
            print(f"模型加载失败: {str(e)}")
            self.pipe = None
            self.loaded = False
            return False

        # xFormers is only an optional speed-up. The method is defined on every
        # diffusers pipeline, so hasattr() is not a real guard: the call raises
        # when xformers is not installed or no GPU is present. Treat that as
        # non-fatal rather than discarding a successfully loaded model.
        if hasattr(pipe, 'enable_xformers_memory_efficient_attention'):
            try:
                pipe.enable_xformers_memory_efficient_attention()
            except Exception as e:
                print(f"xformers unavailable, using default attention: {e}")

        self.pipe = pipe
        self.loaded = True
        print("模型加载完成!")
        return True

    @staticmethod
    def _bbox_to_pixels(bbox: Dict, width: int, height: int) -> Tuple[int, int, int, int]:
        """
        Convert a normalised bbox dict to pixel corners clipped to the image.

        Returns:
            ``(x0, y0, x1, y1)`` with ``0 <= x0 <= x1 <= width`` and
            ``0 <= y0 <= y1 <= height``. The values are therefore safe as slice
            bounds; an unclipped negative start would wrap to the far edge. They
            are also valid rectangle corners, since x1 never precedes x0.
        """
        left = int(bbox['x'] * width)
        top = int(bbox['y'] * height)
        right = left + int(bbox['width'] * width)
        bottom = top + int(bbox['height'] * height)
        x0 = min(max(left, 0), width)
        y0 = min(max(top, 0), height)
        x1 = min(max(right, x0), width)
        y1 = min(max(bottom, y0), height)
        return x0, y0, x1, y1

    def create_layout_mask(self, bbox_data: List[Dict], width: int, height: int) -> torch.Tensor:
        """
        Build a single label map describing the layout.

        Args:
            bbox_data: List of normalised bbox dicts.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            Float32 tensor of shape ``(height, width)``. It holds 0 for
            background and ``i + 1`` inside box ``i``; where boxes overlap,
            the later box wins.
        """
        mask = torch.zeros((height, width), dtype=torch.float32)
        for index, bbox in enumerate(bbox_data, start=1):
            x0, y0, x1, y1 = self._bbox_to_pixels(bbox, width, height)
            mask[y0:y1, x0:x1] = index
        return mask

    def create_attention_mask(self, bbox_data: List[Dict], width: int, height: int) -> List[torch.Tensor]:
        """
        Build one soft-edged mask per bounding box.

        Args:
            bbox_data: List of normalised bbox dicts.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            One CPU float32 tensor of shape ``(height, width)`` per box. Each is
            1.0 inside the box and fades to 0 over a one-pixel border.
        """
        masks = []
        for bbox in bbox_data:
            x0, y0, x1, y1 = self._bbox_to_pixels(bbox, width, height)
            mask = torch.zeros((height, width), dtype=torch.float32)
            mask[y0:y1, x0:x1] = 1.0
            # Soften the edges with a 3x3 box filter (average pooling). This is
            # cheap on CPU, so it is applied unconditionally. The result then
            # no longer depends on whether CUDA happens to be available.
            mask = F.avg_pool2d(mask[None, None], kernel_size=3, stride=1, padding=1)[0, 0]
            masks.append(mask)
        return masks

    def modify_attention_weights(self, attention_weights: torch.Tensor,
                                 attention_masks: List[torch.Tensor],
                                 current_token_idx: int,
                                 boost: float = 0.5) -> torch.Tensor:
        """
        Boost attention inside the region that belongs to the current token.

        Args:
            attention_weights: Original attention weights.
            attention_masks: Per-region masks, e.g. from ``create_attention_mask``.
            current_token_idx: Index of the region/token being processed.
                Indices outside ``[0, len(attention_masks))`` leave the
                weights unchanged.
            boost: Relative increase applied where the mask is 1.0.

        Returns:
            The (possibly) scaled weights, with the same device and dtype as
            ``attention_weights``.
        """
        if not 0 <= current_token_idx < len(attention_masks):
            return attention_weights
        # NOTE(review): the mask is (height, width) in pixel space. This relies
        # on it broadcasting against attention_weights; confirm the caller's shape.
        mask = attention_masks[current_token_idx].to(
            device=attention_weights.device, dtype=attention_weights.dtype
        )
        return attention_weights * (1 + mask * boost)

    @spaces.GPU
    def generate_image(self,
                       prompt: str,
                       bbox_data: List[Dict],
                       negative_prompt: str = "",
                       num_inference_steps: int = 20,
                       guidance_scale: float = 7.5,
                       width: int = 512,
                       height: int = 512,
                       seed: Optional[int] = None) -> Image.Image:
        """
        Generate an image for ``prompt`` with optional region boxes.

        Args:
            prompt: Main text prompt.
            bbox_data: Normalised bbox dicts; may be empty or None.
            negative_prompt: Negative prompt. It is forwarded only when non-empty.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Guidance strength passed to the pipeline.
            width: Output width in pixels.
            height: Output height in pixels.
            seed: Optional seed for a reproducible ``torch.Generator``.

        Returns:
            The generated image. If the model cannot be loaded or generation
            raises, a placeholder image from ``_create_demo_image`` is returned.
        """
        bbox_data = bbox_data or []

        if not self.loaded and not self.load_model():
            # Model unavailable: return a placeholder so the caller still gets an image.
            return self._create_demo_image(prompt, bbox_data, width, height)

        generator = (
            torch.Generator(device=self.device).manual_seed(seed)
            if seed is not None
            else None
        )

        try:
            full_prompt = self._build_full_prompt(prompt, bbox_data)
            if not bbox_data:
                return self._run_pipe(full_prompt, negative_prompt, num_inference_steps,
                                      guidance_scale, width, height, generator)
            return self._generate_with_bbox_control(
                full_prompt, bbox_data, negative_prompt,
                num_inference_steps, guidance_scale,
                width, height, generator,
            )
        except Exception as e:
            print(f"生成图像时出错: {str(e)}")
            return self._create_demo_image(prompt, bbox_data, width, height)

    def _run_pipe(self, prompt: str, negative_prompt: str, num_inference_steps: int,
                  guidance_scale: float, width: int, height: int,
                  generator: Optional[torch.Generator]) -> Image.Image:
        """Run the loaded FluxPipeline once and return its first image."""
        call_kwargs = dict(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            generator=generator,
        )
        # Forward the negative prompt only when it is non-empty. Older
        # FluxPipeline releases reject the argument outright (TypeError). In
        # releases that accept it, it only takes effect with true_cfg_scale > 1,
        # so omitting an empty value there changes nothing.
        if negative_prompt:
            call_kwargs['negative_prompt'] = negative_prompt
        return self.pipe(**call_kwargs).images[0]

    def _build_full_prompt(self, main_prompt: str, bbox_data: List[Dict]) -> str:
        """Append the non-empty region labels to the main prompt, comma-separated."""
        labels = [str(bbox.get('label')) for bbox in (bbox_data or []) if bbox.get('label')]
        if not labels:
            return main_prompt
        return main_prompt + ", " + ", ".join(labels)

    def _generate_with_bbox_control(self, prompt: str, bbox_data: List[Dict],
                                    negative_prompt: str, num_inference_steps: int,
                                    guidance_scale: float, width: int, height: int,
                                    generator: Optional[torch.Generator]) -> Image.Image:
        """
        Generate with bounding-box control. This is currently a placeholder.

        The image is generated normally and the boxes are drawn on top of it.
        """
        # NOTE(review): the region masks are built but not yet fed into the
        # pipeline; the DreamRenderer attention control is still a TODO. Building
        # them here surfaces malformed bbox dicts (KeyError) before the
        # expensive diffusion call.
        attention_masks = self.create_attention_mask(bbox_data, width, height)  # noqa: F841

        image = self._run_pipe(prompt, negative_prompt, num_inference_steps,
                               guidance_scale, width, height, generator)
        return self._add_bbox_overlay(image, bbox_data)

    def _draw_boxes(self, draw: ImageDraw.ImageDraw, bbox_data: List[Dict],
                    width: int, height: int, line_width: int, label_offset: int) -> None:
        """Draw each box outline and its label, cycling through ``_BOX_COLORS``."""
        for i, bbox in enumerate(bbox_data):
            color = self._BOX_COLORS[i % len(self._BOX_COLORS)]
            x0, y0, x1, y1 = self._bbox_to_pixels(bbox, width, height)
            draw.rectangle([x0, y0, x1, y1], outline=color, width=line_width)
            label = bbox.get('label')
            if label:
                # Keep the label on-canvas for boxes that touch the top edge.
                draw.text((x0, max(y0 - label_offset, 0)), str(label), fill=color)

    def _add_bbox_overlay(self, image: Image.Image, bbox_data: List[Dict]) -> Image.Image:
        """Draw the bounding boxes and labels onto ``image`` in place, for demonstration."""
        if not bbox_data:
            return image
        self._draw_boxes(ImageDraw.Draw(image), bbox_data, image.width, image.height,
                         line_width=2, label_offset=15)
        return image

    def _create_demo_image(self, prompt: str, bbox_data: List[Dict],
                           width: int, height: int) -> Image.Image:
        """
        Build a placeholder image, used when the model is unavailable or generation fails.

        The image has a vertical gradient background, the prompt text and the
        bounding boxes with their labels.
        """
        image = Image.new('RGB', (width, height))
        draw = ImageDraw.Draw(image)

        # Vertical gradient background, one horizontal line per row.
        for y in range(height):
            level = int(255 * (y / height))
            draw.line([(0, y), (width, y)],
                      fill=(100 + level // 3, 150 + level // 4, 200 + level // 5))

        draw.text((10, 10), f"Prompt: {prompt}", fill='white')
        draw.text((10, 30), "DreamRenderer Demo", fill='white')

        self._draw_boxes(draw, bbox_data or [], width, height,
                         line_width=3, label_offset=20)
        return image