File size: 6,211 Bytes
43b3c10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import os
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Union

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMPipeline
from diffusers import DDPMScheduler
from diffusers import UNet2DModel
from diffusers.models.unets.unet_2d import UNet2DOutput
from huggingface_hub import create_repo, upload_folder
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
class DPM(UNet2DModel):
    """`UNet2DModel` variant that cross-attends its bottleneck to a prototype.

    After the down blocks and before the mid block, the spatial feature map is
    flattened into a token sequence and used as the query of a multi-head
    attention layer whose keys and values are a caller-supplied `prototype`
    tensor. The attention output is added back to the feature map as a
    residual.
    """

    def __init__(self, *args, **kwargs):
        """Build the base UNet, then add the bottleneck cross-attention layer.

        All positional and keyword arguments are forwarded unchanged to
        `UNet2DModel.__init__`.
        """
        super().__init__(*args, **kwargs)
        # The attention width is the channel count of the last entry in
        # `block_out_channels`. The number of heads is fixed at 8, so that
        # channel count must be divisible by 8 (nn.MultiheadAttention
        # requires embed_dim % num_heads == 0).
        self.bottleneck_attn = nn.MultiheadAttention(
            embed_dim=self.config.block_out_channels[-1],
            num_heads=8,
            batch_first=True,
        )

    def _apply_prototype_attention(
        self, sample: torch.Tensor, prototype: torch.Tensor
    ) -> torch.Tensor:
        """Cross-attend the feature map `sample` to `prototype`, residually.

        Args:
            sample: Feature map of shape `(batch, channel, height, width)`.
            prototype: Key/value tokens of shape `(batch, tokens, channel)`.

        Returns:
            A tensor with the same shape as `sample`.

        Raises:
            ValueError: If `prototype` is not 3-D, or its batch size or last
                dimension does not match `sample`.
        """
        b, c, h, w = sample.shape

        # Validate here so shape errors are reported clearly instead of
        # surfacing from inside nn.MultiheadAttention.
        if prototype.dim() != 3:
            raise ValueError(
                "`prototype` must have shape (batch, tokens, channels); "
                f"got a tensor with {prototype.dim()} dimension(s)"
            )
        if prototype.shape[0] != b:
            raise ValueError(
                f"`prototype` batch size ({prototype.shape[0]}) does not match "
                f"the sample batch size ({b})"
            )
        if prototype.shape[-1] != c:
            raise ValueError(
                f"`prototype` last dimension ({prototype.shape[-1]}) does not "
                f"match the bottleneck channel count ({c})"
            )

        # (B, C, H, W) -> (B, H*W, C): one query token per spatial location.
        # flatten/reshape (rather than view) so any memory layout is accepted.
        query = sample.flatten(2).transpose(1, 2)

        # Keys and values are both the prototype, moved to the sample's
        # device and dtype so the attention layer sees consistent inputs.
        key = value = prototype.to(device=sample.device, dtype=sample.dtype)

        attn_output, _ = self.bottleneck_attn(query, key, value)

        # (B, H*W, C) -> (B, C, H, W)
        attn_output = attn_output.transpose(1, 2).reshape(b, c, h, w)

        # Residual connection.
        return sample + attn_output

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        prototype: Optional[torch.Tensor] = None,
    ) -> Union[UNet2DOutput, Tuple]:
        r"""
        The [`UNet2DModel`] forward method, extended with prototype cross-attention.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_2d.UNet2DOutput`] instead of a plain tuple.
            prototype (`torch.Tensor`):
                Key/value tokens for the bottleneck cross-attention, with shape `(batch, tokens, channel)` where
                `channel` equals the bottleneck channel count. Although the default is `None` (to keep the
                parameter keyword-compatible with the base signature), a value is required.

        Returns:
            [`~models.unets.unet_2d.UNet2DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unets.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.

        Raises:
            ValueError: If `prototype` is missing or has an incompatible shape, if `class_labels` is missing while
                the model has a class embedding, or if `class_labels` is given while it has none.
        """
        # Fail fast: the prototype is mandatory, so check it before doing any work.
        if prototype is None:
            raise ValueError("You must provide a `prototype` tensor for cross-attention")

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb
        elif self.class_embedding is None and class_labels is not None:
            raise ValueError("class_embedding needs to be initialized in order to use class conditioning")

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 3b. cross-attention between the bottleneck features and the prototype
        sample = self._apply_prototype_attention(sample, prototype)

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.up_blocks:
            # Each up block consumes as many residuals as it has resnets,
            # taken from the end of the stack collected on the way down.
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            # Reshape timesteps to (batch, 1, 1, ...) so the division broadcasts per sample.
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        if not return_dict:
            return (sample,)

        return UNet2DOutput(sample=sample)