songtianhui committed on
Commit 439279b · 1 Parent(s): f063156
Files changed (5)
  1. README.md +5 -4
  2. app.py +77 -0
  3. modeling/dmm_pipeline.py +326 -0
  4. modeling/dmm_unet.py +0 -0
  5. requirements.txt +13 -0
README.md CHANGED
@@ -1,13 +1,14 @@
 ---
 title: DMM
-emoji: 🌖
-colorFrom: green
-colorTo: pink
+emoji: 🖼
+colorFrom: purple
+colorTo: red
 sdk: gradio
 sdk_version: 5.25.2
 app_file: app.py
 pinned: false
+license: apache-2.0
 short_description: Demo for paper DMM
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,77 @@
+import spaces
+import torch
+import gradio as gr
+
+from modeling.dmm_pipeline import StableDiffusionDMMPipeline
+from huggingface_hub import snapshot_download
+
+
+ckpt_path = "ckpt"
+snapshot_download(repo_id="MCG-NJU/DMM", local_dir=ckpt_path)
+
+pipe = StableDiffusionDMMPipeline.from_pretrained(
+    ckpt_path,
+    torch_dtype=torch.float16,
+    use_safetensors=True
+)
+pipe.to("cuda")
+
+@spaces.GPU
+def generate(prompt: str,
+             negative_prompt: str,
+             model_id: int,
+             seed: int = 1234,
+             all: bool = True):
+    if all:
+        outputs = []
+        for i in range(pipe.unet.get_num_models()):
+            output = pipe(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                width=512,
+                height=512,
+                num_inference_steps=25,
+                guidance_scale=7,
+                model_id=i,
+                generator=torch.Generator().manual_seed(seed),
+            ).images[0]
+            outputs.append(output)
+        return outputs
+    else:
+        output = pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            width=512,
+            height=512,
+            num_inference_steps=25,
+            guidance_scale=7,
+            model_id=int(model_id),
+            generator=torch.Generator().manual_seed(seed),
+        ).images[0]
+        return [output,]
+
+
+
+def main():
+    with gr.Blocks() as demo:
+        gr.Markdown("# DMM")
+        with gr.Row():
+            with gr.Column():
+                prompt = gr.Textbox("portrait photo of a girl, long golden hair, flowers, best quality", label="Prompt")
+                negative_prompt = gr.Textbox("worst quality,low quality,normal quality,lowres,watermark,nsfw", label="Negative Prompt")
+                seed = gr.Number(1234, label="Seed", precision=0)
+            with gr.Column():
+                model_id = gr.Slider(label="Model Index", minimum=0, maximum=7, step=1)
+                all_check = gr.Checkbox(label="All")
+                btn = gr.Button("Submit", variant="primary")
+        output = gr.Gallery(label="images")
+
+        btn.click(generate,
+                  inputs=[prompt, negative_prompt, model_id, seed, all_check],
+                  outputs=[output])
+
+    demo.launch()
+
+
+if __name__ == "__main__":
+    main()
modeling/dmm_pipeline.py ADDED
@@ -0,0 +1,326 @@
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+
+from diffusers.image_processor import PipelineImageInput
+from diffusers.utils import (
+    deprecate,
+    logging,
+    replace_example_docstring,
+)
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline, retrieve_timesteps, rescale_noise_cfg
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusionPipeline
+
+        >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+        >>> pipe = pipe.to("cuda")
+
+        >>> prompt = "a photo of an astronaut riding a horse on mars"
+        >>> image = pipe(prompt).images[0]
+        ```
+"""
+
+
+class StableDiffusionDMMPipeline(StableDiffusionPipeline):
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        timesteps: List[int] = None,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        model_id: int = 0,
+        enable_model_id: bool = True,
+        **kwargs,
+    ):
+        r"""
+        The call function to the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                A higher guidance scale value encourages the model to generate images closely linked to the text
+                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor is generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, text embeddings are generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+                using zero terminal SNR.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is a list with the generated images and the
+                second element is a list of `bool`s indicating whether the corresponding generated image contains
+                "not-safe-for-work" (nsfw) content.
+        """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+        # to deal with lora scaling and other possible forward hooks
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            callback_on_step_end_tensor_inputs,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # 3. Encode input prompt
+        lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            self.do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+            clip_skip=self.clip_skip,
+        )
+
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        if ip_adapter_image is not None:
+            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
+        # 4. Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 6.1 Add image embeds for IP-Adapter
+        ipadapter_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+        # 6.2 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
+        # 6.3 Add model_ids
+        assert 0 <= model_id and model_id < self.unet.model_embedding.num_embeddings
+        model_ids = torch.LongTensor([model_id] * len(latents) * (2 if self.do_classifier_free_guidance else 1)).to(device)  # (b,)
+        added_cond_kwargs = {"model_ids": model_ids}
+        if ipadapter_cond_kwargs is not None:
+            added_cond_kwargs.update(ipadapter_cond_kwargs)
+        # print(added_cond_kwargs)
+
+        # 7. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                    enable_model_id=enable_model_id,
+                )[0]
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+                0
+            ]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
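
Beyond the stock `StableDiffusionPipeline` arguments, this subclass only adds `model_id` (which merged model to sample from) and `enable_model_id`. A minimal usage sketch, assuming the checkpoint has already been downloaded to `ckpt` as `app.py` does above; the prompt, seed, and `model_id=3` are purely illustrative:

```python
import torch
from modeling.dmm_pipeline import StableDiffusionDMMPipeline

# Assumes the merged checkpoint is already at "ckpt"
# (app.py fetches it with snapshot_download from MCG-NJU/DMM).
pipe = StableDiffusionDMMPipeline.from_pretrained(
    "ckpt", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")

# model_id selects one of the merged models; valid indices are
# 0 .. pipe.unet.get_num_models() - 1 (the demo slider exposes 0-7).
image = pipe(
    prompt="portrait photo of a girl, long golden hair, flowers, best quality",
    negative_prompt="worst quality, low quality, lowres, watermark",
    width=512,
    height=512,
    num_inference_steps=25,
    guidance_scale=7,
    model_id=3,  # illustrative choice
    generator=torch.Generator().manual_seed(1234),
).images[0]
image.save("dmm_model_3.png")
```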
modeling/dmm_unet.py ADDED
The diff for this file is too large to render. See raw diff
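
Although the UNet diff is not rendered, the pipeline and `app.py` rely on a small surface of it: a `model_embedding` whose `num_embeddings` equals the number of merged models, a `get_num_models()` helper, and a forward pass that reads `model_ids` from `added_cond_kwargs` when `enable_model_id` is set. A hypothetical, self-contained sketch of just that surface (the embedding dimension and the 8-model count are assumptions; `app.py`'s slider exposes indices 0-7):

```python
import torch
import torch.nn as nn

class ModelEmbeddingSketch(nn.Module):
    """Hypothetical stand-in for the model-selection part of dmm_unet.py."""

    def __init__(self, num_models: int = 8, embed_dim: int = 1280):
        super().__init__()
        # The pipeline asserts against `self.unet.model_embedding.num_embeddings`.
        self.model_embedding = nn.Embedding(num_models, embed_dim)

    def get_num_models(self) -> int:
        # app.py iterates range(pipe.unet.get_num_models()) when "All" is checked.
        return self.model_embedding.num_embeddings

    def forward(self, added_cond_kwargs: dict, enable_model_id: bool = True):
        # The DMM UNet receives `model_ids` via `added_cond_kwargs` and maps them to
        # one embedding per batch element, which presumably conditions the UNet blocks.
        if not enable_model_id:
            return None
        model_ids = added_cond_kwargs["model_ids"]   # shape (batch,)
        return self.model_embedding(model_ids)       # shape (batch, embed_dim)


emb = ModelEmbeddingSketch()
print(emb.get_num_models())                          # 8
vec = emb({"model_ids": torch.LongTensor([3, 3])})   # (2, 1280)
```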
 
requirements.txt ADDED
@@ -0,0 +1,13 @@
+numpy
+fsspec
+Pillow==9.4.0
+torch==2.2
+accelerate
+transformers
+diffusers
+Jinja2==3.1.4
+huggingface-hub
+retrying
+setuptools>=40.8.0
+open_clip_torch==2.29.0
+gradio