""" DreamRenderer实现模块 """ import torch import torch.nn.functional as F from diffusers import FluxPipeline from PIL import Image, ImageDraw import numpy as np from typing import List, Dict, Optional, Tuple import spaces class DreamRendererPipeline: """ DreamRenderer管道实现 """ def __init__(self, model_id: str = "black-forest-labs/FLUX.1-dev"): """ 初始化DreamRenderer管道 Args: model_id: 使用的模型ID """ self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model_id = model_id self.pipe = None self.loaded = False def load_model(self): """加载FLUX模型""" try: print(f"正在加载模型: {self.model_id}") self.pipe = FluxPipeline.from_pretrained( self.model_id, torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32, use_safetensors=True ) self.pipe = self.pipe.to(self.device) # 启用内存高效的注意力机制 if hasattr(self.pipe, 'enable_xformers_memory_efficient_attention'): self.pipe.enable_xformers_memory_efficient_attention() self.loaded = True print("模型加载完成!") return True except Exception as e: print(f"模型加载失败: {str(e)}") self.loaded = False return False def create_layout_mask(self, bbox_data: List[Dict], width: int, height: int) -> torch.Tensor: """ 根据边界框数据创建布局掩码 Args: bbox_data: 边界框数据列表 width: 图像宽度 height: 图像高度 Returns: 布局掩码张量 """ mask = torch.zeros((height, width), dtype=torch.float32) for i, bbox in enumerate(bbox_data): x = int(bbox['x'] * width) y = int(bbox['y'] * height) w = int(bbox['width'] * width) h = int(bbox['height'] * height) # 在掩码中标记区域 mask[y:y+h, x:x+w] = i + 1 return mask def create_attention_mask(self, bbox_data: List[Dict], width: int, height: int) -> List[torch.Tensor]: """ 为每个实例创建注意力掩码 Args: bbox_data: 边界框数据列表 width: 图像宽度 height: 图像高度 Returns: 注意力掩码列表 """ masks = [] for bbox in bbox_data: mask = torch.zeros((height, width), dtype=torch.float32) x = int(bbox['x'] * width) y = int(bbox['y'] * height) w = int(bbox['width'] * width) h = int(bbox['height'] * height) # 创建软边界的掩码 mask[y:y+h, x:x+w] = 1.0 # 应用高斯模糊以创建软边界 if torch.cuda.is_available(): mask = mask.unsqueeze(0).unsqueeze(0).cuda() mask = F.avg_pool2d(mask, kernel_size=3, stride=1, padding=1) mask = mask.squeeze().cpu() masks.append(mask) return masks def modify_attention_weights(self, attention_weights: torch.Tensor, attention_masks: List[torch.Tensor], current_token_idx: int) -> torch.Tensor: """ 修改注意力权重以实现区域控制 Args: attention_weights: 原始注意力权重 attention_masks: 注意力掩码列表 current_token_idx: 当前token索引 Returns: 修改后的注意力权重 """ # 这里实现DreamRenderer的核心注意力修改逻辑 # 根据当前token和对应的区域掩码调整注意力权重 if current_token_idx < len(attention_masks): mask = attention_masks[current_token_idx] # 将掩码应用到注意力权重 if mask.device != attention_weights.device: mask = mask.to(attention_weights.device) # 增强对应区域的注意力 attention_weights = attention_weights * (1 + mask * 0.5) return attention_weights @spaces.GPU def generate_image(self, prompt: str, bbox_data: List[Dict], negative_prompt: str = "", num_inference_steps: int = 20, guidance_scale: float = 7.5, width: int = 512, height: int = 512, seed: Optional[int] = None) -> Image.Image: """ 生成图像的主要函数 Args: prompt: 主提示词 bbox_data: 边界框数据 negative_prompt: 负向提示词 num_inference_steps: 推理步数 guidance_scale: 引导强度 width: 图像宽度 height: 图像高度 seed: 随机种子 Returns: 生成的图像 """ if not self.loaded: if not self.load_model(): # 如果模型加载失败,返回一个演示图像 return self._create_demo_image(prompt, bbox_data, width, height) # 设置随机种子 if seed is not None: generator = torch.Generator(device=self.device).manual_seed(seed) else: generator = None try: # 构建完整的提示词 full_prompt = self._build_full_prompt(prompt, bbox_data) # 如果没有边界框数据,直接使用标准生成 if not bbox_data: image = self.pipe( prompt=full_prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, width=width, height=height, generator=generator ).images[0] else: # 使用DreamRenderer的区域控制逻辑 image = self._generate_with_bbox_control( full_prompt, bbox_data, negative_prompt, num_inference_steps, guidance_scale, width, height, generator ) return image except Exception as e: print(f"生成图像时出错: {str(e)}") # 返回演示图像 return self._create_demo_image(prompt, bbox_data, width, height) def _build_full_prompt(self, main_prompt: str, bbox_data: List[Dict]) -> str: """构建包含区域描述的完整提示词""" full_prompt = main_prompt if bbox_data: region_descriptions = [] for i, bbox in enumerate(bbox_data): if bbox['label']: region_descriptions.append(f"{bbox['label']}") if region_descriptions: full_prompt += ", " + ", ".join(region_descriptions) return full_prompt def _generate_with_bbox_control(self, prompt: str, bbox_data: List[Dict], negative_prompt: str, num_inference_steps: int, guidance_scale: float, width: int, height: int, generator: Optional[torch.Generator]) -> Image.Image: """使用边界框控制生成图像""" # 创建注意力掩码 attention_masks = self.create_attention_mask(bbox_data, width, height) # 这里应该实现DreamRenderer的核心算法 # 包括注意力修改、交叉注意力控制等 # 现在先用标准方法生成,后续可以替换为实际的DreamRenderer实现 image = self.pipe( prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, width=width, height=height, generator=generator ).images[0] # 在生成的图像上绘制边界框作为演示 image = self._add_bbox_overlay(image, bbox_data) return image def _add_bbox_overlay(self, image: Image.Image, bbox_data: List[Dict]) -> Image.Image: """在图像上添加边界框覆盖层(用于演示)""" if not bbox_data: return image draw = ImageDraw.Draw(image) colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange', 'pink', 'cyan'] for i, bbox in enumerate(bbox_data): color = colors[i % len(colors)] x = int(bbox['x'] * image.width) y = int(bbox['y'] * image.height) w = int(bbox['width'] * image.width) h = int(bbox['height'] * image.height) # 绘制边界框 draw.rectangle([x, y, x+w, y+h], outline=color, width=2) # 绘制标签 if bbox['label']: draw.text((x, y-15), bbox['label'], fill=color) return image def _create_demo_image(self, prompt: str, bbox_data: List[Dict], width: int, height: int) -> Image.Image: """创建演示图像(当模型加载失败时使用)""" # 创建一个渐变背景 image = Image.new('RGB', (width, height)) draw = ImageDraw.Draw(image) # 绘制渐变背景 for y in range(height): color_value = int(255 * (y / height)) color = (100 + color_value//3, 150 + color_value//4, 200 + color_value//5) draw.line([(0, y), (width, y)], fill=color) # 添加提示词文本 draw.text((10, 10), f"Prompt: {prompt}", fill='white') draw.text((10, 30), "DreamRenderer Demo", fill='white') # 绘制边界框 colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange'] for i, bbox in enumerate(bbox_data): color = colors[i % len(colors)] x = int(bbox['x'] * width) y = int(bbox['y'] * height) w = int(bbox['width'] * width) h = int(bbox['height'] * height) # 绘制边界框 draw.rectangle([x, y, x+w, y+h], outline=color, width=3) # 绘制标签 if bbox['label']: draw.text((x, y-20), bbox['label'], fill=color) return image