Klayand committed
Commit 949adbc · 1 Parent(s): ad68b20

add inference code

README.md CHANGED
@@ -1,3 +1,91 @@
  ---
  license: apache-2.0
  ---
+
+ # NPNet Pipeline Usage Guide😄
+
+ ## Overview
+
+ This guide explains how to use NPNet, a noise prompt network that transforms random Gaussian noise into "golden" noise by adding a small, desirable perturbation derived from the text prompt, boosting both the overall quality and the semantic faithfulness of the synthesized images.
+
+ The inference code provided here supports several models: ***Stable Diffusion XL, DreamShaper-xl-v2-turbo, and Hunyuan-DiT***.
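+
+ Conceptually, NPNet replaces the initial latent of a text-to-image pipeline: sample Gaussian noise, push it through NPNet together with the prompt embedding, and hand the result to the pipeline as `latents`. A minimal sketch mirroring `npnet_pipeline.py` (the local checkpoint path is an assumption):
+
+ ```python
+ import torch
+ from diffusers import StableDiffusionXLPipeline
+ from npnet_pipeline import NPNet
+
+ device = "cuda"
+ prompt = "A banana on the left of an apple."
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to(device)
+
+ noise = torch.randn(1, 4, 128, 128, device=device)       # random Gaussian latent
+ prompt_embeds, _, _, _ = pipe.encode_prompt(prompt=prompt, device=device)
+
+ npnet = NPNet("SDXL", "weights/sdxl.pth")                 # assumed checkpoint path
+ with torch.no_grad():
+     golden = npnet(noise, prompt_embeds)                  # golden noise, same shape as `noise`
+ image = pipe(prompt=prompt, latents=golden.half()).images[0]
+ ```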
+
+ ## Requirements
+
+ - `python >= 3.8.0`
+ - `pytorch` (CUDA build)
+ - `diffusers`
+ - `Pillow` (imported as `PIL`)
+ - `numpy`
+ - `timm`
+ - `argparse` (ships with the Python standard library)
+ - `einops`
+
+ ## Installation🚀️
+
+ Make sure you have set up a `python` environment and installed a CUDA build of `pytorch`. Before running the script, ensure all the required packages are installed. You can install them with:
+
+ ```bash
+ pip install diffusers Pillow numpy timm einops
+ ```
+
+ ## Usage👀️
+
+ To use the NPNet pipeline, you need to run the `npnet_pipeline.py` script with appropriate command-line arguments. Below are the available options:
+
+ ### Command-Line Arguments
+
+ - `--pipeline`: Select the model pipeline (`SDXL`, `DreamShaper`, `DiT`). Default is `SDXL`.
+ - `--prompt`: The text prompt from which the image is generated. Default is "A banana on the left of an apple."
+ - `--inference-step`: Number of inference steps for the diffusion process. Default is 50.
+ - `--cfg`: Classifier-free guidance scale. Default is 5.5 (the script recommends 3.5 for `DreamShaper` and 5.0 for `DiT`).
+ - `--pretrained-path`: Path to the pretrained NPNet weights. The default is a placeholder, so point it at the downloaded `.pth` checkpoint.
+ - `--size`: The height and width of the generated image in pixels. Default is 1024.
+
+ ### Running the Script
+
+ Run the script from the command line by navigating to the directory containing `npnet_pipeline.py` and executing:
+
+ ```bash
+ python npnet_pipeline.py --pipeline SDXL --prompt "A banana on the left of an apple." --size 1024
+ ```
+
+ This command generates an image for the prompt "A banana on the left of an apple." using the Stable Diffusion XL model at 1024x1024 pixels.
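+
+ For the other pipelines, the guidance scales recommended in the script are 3.5 for DreamShaper and 5.0 for Hunyuan-DiT, e.g.:
+
+ ```bash
+ python npnet_pipeline.py --pipeline DreamShaper --cfg 3.5 --pretrained-path weights/dreamshaper.pth
+ ```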
+
+ ### Output🎉️
+
+ The script saves two images:
+
+ - A standard image generated by the diffusion model from the random noise.
+ - A golden image generated by the diffusion model from the NPNet golden noise.
+
+ Both images are saved in the current directory, with filenames derived from the model and the prompt.
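+
+ With the default arguments, the save pattern in the script, `{pipeline}_{prompt}_{standard|golden}_image.jpg`, produces:
+
+ ```
+ SDXL_A banana on the left of an apple._standard_image.jpg
+ SDXL_A banana on the left of an apple._golden_image.jpg
+ ```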
+
+ ## Pre-trained Weights Download❤️
+
+ We provide the pre-trained NPNet weights for Stable Diffusion XL, DreamShaper-xl-v2-turbo, and Hunyuan-DiT on [Google Drive](https://drive.google.com/drive/folders/1Z0wg4HADhpgrztyT3eWijPbJJN5Y2jQt?usp=drive_link); the same checkpoints are included in this repository under `weights/`.
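+
+ Each checkpoint is a single `torch.save` dictionary holding the pieces that `NPNet.get_model` loads (note that the text-embedding key is spelled "embeeding" in the released files):
+
+ ```python
+ import torch
+
+ ckpt = torch.load("weights/sdxl.pth", map_location="cpu")
+ # keys used by the loading code:
+ # "unet_svd", "unet_embedding", "embeeding", "alpha", "beta"
+ print(sorted(ckpt.keys()))
+ ```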
+
+ ## Citation
+ If you find our code useful for your research, please cite our paper:
+
+ ```bibtex
+ @misc{zhou2024goldennoisediffusionmodels,
+       title={Golden Noise for Diffusion Models: A Learning Framework},
+       author={Zikai Zhou and Shitong Shao and Lichen Bai and Zhiqiang Xu and Bo Han and Zeke Xie},
+       year={2024},
+       eprint={2411.09502},
+       archivePrefix={arXiv},
+       primaryClass={cs.LG},
+       url={https://arxiv.org/abs/2411.09502},
+ }
+ ```
+
+ ## 🙏 Acknowledgements
+
+ We thank the community and contributors for their invaluable support in developing NPNet.
+ We thank @DataCTE for building a ComfyUI integration of the NPNet inference code: [ComfyUI_Golden-Noise](https://github.com/DataCTE/ComfyUI_Golden-Noise).
+ We thank @asagi4 for building another ComfyUI integration: [ComfyUI-NPNet](https://github.com/asagi4/ComfyUI-NPNet).
+
+ ---
+
model/NoiseTransformer.py ADDED
@@ -0,0 +1,26 @@
+ import torch.nn as nn
+
+ from torch.nn import functional as F
+ from timm import create_model
+
+
+ __all__ = ['NoiseTransformer']
+
+
+ class NoiseTransformer(nn.Module):
+     def __init__(self, resolution=128):
+         super().__init__()
+         # resize between the latent resolution and the 224x224 input expected by Swin
+         self.upsample = lambda x: F.interpolate(x, [224, 224])
+         self.downsample = lambda x: F.interpolate(x, [resolution, resolution])
+         # 1x1 convs: downconv maps the 4 latent channels to the 3 channels the
+         # pretrained Swin expects; upconv projects the resized features back to 4 channels
+         self.upconv = nn.Conv2d(7, 4, (1, 1), (1, 1), (0, 0))
+         self.downconv = nn.Conv2d(4, 3, (1, 1), (1, 1), (0, 0))
+         self.swin = create_model("swin_tiny_patch4_window7_224", pretrained=True)
+
+     def forward(self, x, residual=False):
+         out = self.upconv(self.downsample(self.swin.forward_features(self.downconv(self.upsample(x)))))
+         if residual:
+             out = out + x
+         return out
model/SVDNoiseUnet.py ADDED
@@ -0,0 +1,109 @@
+ import torch
+ import torch.nn as nn
+ import einops
+
+ from torch.nn import functional as F
+ from torch.jit import Final
+ from timm.layers import use_fused_attn
+
+ __all__ = ['SVDNoiseUnet', 'SVDNoiseUnet_Concise']
+
+
+ class Attention(nn.Module):
+     fused_attn: Final[bool]
+
+     def __init__(
+             self,
+             dim: int,
+             num_heads: int = 8,
+             qkv_bias: bool = False,
+             qk_norm: bool = False,
+             attn_drop: float = 0.,
+             proj_drop: float = 0.,
+             norm_layer: nn.Module = nn.LayerNorm,
+     ) -> None:
+         super().__init__()
+         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim ** -0.5
+         self.fused_attn = use_fused_attn()
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv.unbind(0)
+         q, k = self.q_norm(q), self.k_norm(k)
+
+         if self.fused_attn:
+             x = F.scaled_dot_product_attention(
+                 q, k, v,
+                 dropout_p=self.attn_drop.p if self.training else 0.,
+             )
+         else:
+             q = q * self.scale
+             attn = q @ k.transpose(-2, -1)
+             attn = attn.softmax(dim=-1)
+             attn = self.attn_drop(attn)
+             x = attn @ v
+
+         x = x.transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class SVDNoiseUnet(nn.Module):
+     def __init__(self, in_channels=4, out_channels=4, resolution=128):  # resolution = size // 8
+         super(SVDNoiseUnet, self).__init__()
+
+         # 256 when resolution=128: matches one side of the folded 256x256 matrix below
+         _in = int(resolution * in_channels // 2)
+         _out = int(resolution * out_channels // 2)
+         self.mlp1 = nn.Sequential(
+             nn.Linear(_in, 64),
+             nn.ReLU(inplace=True),
+             nn.Linear(64, _out),
+         )
+         self.mlp2 = nn.Sequential(
+             nn.Linear(_in, 64),
+             nn.ReLU(inplace=True),
+             nn.Linear(64, _out),
+         )
+         self.mlp3 = nn.Sequential(
+             nn.Linear(_in, _out),
+         )
+
+         self.attention = Attention(_out)
+
+         self.bn = nn.BatchNorm2d(_out)
+
+         self.mlp4 = nn.Sequential(
+             nn.Linear(_out, 1024),
+             nn.ReLU(inplace=True),
+             nn.Linear(1024, _out),
+         )
+
+     def forward(self, x, residual=False):
+         b, c, h, w = x.shape
+         # fold the 4 latent channels into the spatial dims: [b, 4, 128, 128] -> [b, 256, 256]
+         x = einops.rearrange(x, "b (a c) h w -> b (a h) (c w)", a=2, c=2)
+         # U -> [b, 256, 256], s -> [b, 256], V -> [b, 256, 256]
+         # (torch.linalg.svd returns V^H as the third output, so U @ diag(s) @ V reconstructs x)
+         U, s, V = torch.linalg.svd(x)
+         U_T = U.permute(0, 2, 1)
+         # mix singular vectors and values: s -> [b, 1, 256] broadcast to [b, 256, 256]
+         out = self.mlp1(U_T) + self.mlp2(V) + self.mlp3(s).unsqueeze(1)
+         out = self.attention(out).mean(1)
+         out = self.mlp4(out) + s  # predict a residual on the singular values
+         pred = U @ torch.diag_embed(out) @ V  # reassemble with the predicted spectrum
+         return einops.rearrange(pred, "b (a h) (c w) -> b (a c) h w", a=2, c=2)
+
+
+ class SVDNoiseUnet_Concise(nn.Module):
+     def __init__(self, in_channels=4, out_channels=4, resolution=128):
+         super(SVDNoiseUnet_Concise, self).__init__()
+
model/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .NoiseTransformer import NoiseTransformer
+ from .SVDNoiseUnet import SVDNoiseUnet
npnet_pipeline.py ADDED
@@ -0,0 +1,168 @@
+ import torch
+ import argparse
+ import torch.nn as nn
+
+ from diffusers.models.normalization import AdaGroupNorm
+ from diffusers import DPMSolverMultistepScheduler, \
+     StableDiffusionXLPipeline, HunyuanDiTPipeline
+
+ from model import NoiseTransformer, SVDNoiseUnet
+
+
+ class NPNet(nn.Module):
+     def __init__(self, model_id, pretrained_path, device='cuda') -> None:
+         super(NPNet, self).__init__()
+
+         assert model_id in ['SDXL', 'DreamShaper', 'DiT']
+
+         self.model_id = model_id
+         self.device = device
+         self.pretrained_path = pretrained_path
+
+         (
+             self.unet_svd,
+             self.unet_embedding,
+             self.text_embedding,
+             self._alpha,
+             self._beta
+         ) = self.get_model()
+
+     def get_model(self):
+         unet_embedding = NoiseTransformer(resolution=128).to(self.device).to(torch.float32)
+         unet_svd = SVDNoiseUnet(resolution=128).to(self.device).to(torch.float32)
+
+         # Hunyuan-DiT prompt embeddings are 77 x 1024; the SDXL-based models use 77 x 2048
+         if self.model_id == 'DiT':
+             text_embedding = AdaGroupNorm(1024 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
+         else:
+             text_embedding = AdaGroupNorm(2048 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
+
+         if '.pth' in self.pretrained_path:
+             golden_unet = torch.load(self.pretrained_path, map_location=self.device)
+             unet_svd.load_state_dict(golden_unet["unet_svd"])
+             unet_embedding.load_state_dict(golden_unet["unet_embedding"])
+             # the key is spelled "embeeding" in the released checkpoints
+             text_embedding.load_state_dict(golden_unet["embeeding"])
+             _alpha = golden_unet["alpha"]
+             _beta = golden_unet["beta"]
+
+             print("Load Successfully!")
+
+             return unet_svd, unet_embedding, text_embedding, _alpha, _beta
+         else:
+             raise FileNotFoundError(f"No pretrained weights found at {self.pretrained_path}!")
+
+     def forward(self, initial_noise, prompt_embeds):
+         # condition the initial noise on the flattened prompt embedding via AdaGroupNorm
+         prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
+         text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)
+
+         encoder_hidden_states_svd = initial_noise
+         encoder_hidden_states_embedding = initial_noise + text_emb
+
+         golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())
+
+         # combine the SVD branch, the text branch, and the embedding branch
+         # with the learned scalars alpha and beta
+         golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
+                 2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding
+
+         return golden_noise
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+
+     # model and dataset construction
+     parser.add_argument('--pipeline', default='SDXL',
+                         choices=['SDXL', 'DreamShaper', 'DiT'], type=str)
+     parser.add_argument('--prompt', default='A banana on the left of an apple.', type=str)
+     parser.add_argument("--inference-step", default=50, type=int)
+
+     # recommended guidance scale: 3.5 for DreamShaper, 5.0 for DiT, 5.5 otherwise
+     parser.add_argument("--cfg", default=5.5, type=float)
+
+     # model pretrained weight path
+     parser.add_argument('--pretrained-path', type=str,
+                         default='xxx')
+
+     parser.add_argument("--size", default=1024, type=int)
+
+     args = parser.parse_args()
+
+     print("generating config:")
+     print(f"Config: {args}")
+     print('-' * 100)
+
+     return args
+
+
+ def main(args):
+     dtype = torch.float16
+     device = torch.device('cuda')
+
+     if args.pipeline == 'SDXL':
+         pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",
+                                                          variant="fp16", use_safetensors=True,
+                                                          torch_dtype=torch.float16).to(device)
+     elif args.pipeline == 'DreamShaper':
+         pipe = StableDiffusionXLPipeline.from_pretrained("lykon/dreamshaper-xl-v2-turbo",
+                                                          torch_dtype=torch.float16,
+                                                          variant="fp16").to(device)
+     else:
+         pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",
+                                                   torch_dtype=torch.float16).to(device)
+
+     pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+     pipe.enable_model_cpu_offload()
+
+     # create the initial noise (latent resolution is size // 8, i.e. 128 for 1024x1024 images)
+     latent = torch.randn(1, 4, 128, 128, dtype=dtype).to(device)
+
+     # use the pre-trained text encoder in T2I models to encode prompts
+     prompt_embeds, _, _, _ = pipe.encode_prompt(prompt=args.prompt, device=device)
+
+     # create NPNet to get the target noise
+     npn_net = NPNet(args.pipeline, args.pretrained_path)
+
+     with torch.no_grad():
+         golden_noise = npn_net(latent, prompt_embeds)
+
+     # standard inference pipeline
+     latent = latent.half()
+     golden_noise = golden_noise.half()
+
+     pipe = pipe.to(torch.float16)
+
+     standard_img = pipe(
+         prompt=args.prompt,
+         height=args.size,
+         width=args.size,
+         num_inference_steps=args.inference_step,
+         guidance_scale=args.cfg,
+         latents=latent).images[0]
+
+     golden_img = pipe(
+         prompt=args.prompt,
+         height=args.size,
+         width=args.size,
+         num_inference_steps=args.inference_step,
+         guidance_scale=args.cfg,
+         latents=golden_noise).images[0]
+
+     # image save path
+     standard_img.save(f"{args.pipeline}_{args.prompt}_standard_image.jpg")
+     golden_img.save(f"{args.pipeline}_{args.prompt}_golden_image.jpg")
+
+
+ if __name__ == '__main__':
+     args = get_args()
+     main(args)
weights/dit.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a2de4810f96d1c60682632d72207f202af9a862c2f9268be7f24c8c56aec5b5d
+ size 119449411
weights/dreamshaper.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5bb268ce36e851549831ad7933b26a7597b325c19f761738fbd95d58f57cc41d
+ size 121965392
weights/sdxl.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1a8629a4d939f9ed8f02ed2ad39b8317b701fb9e59d175ce186512e4a2687e48
+ size 121965599