kadirnar committed on
Commit 6f2d84f · verified · 1 Parent(s): b4f278b

Update app.py
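
In short: the dynamic memory_management load/unload calls are removed, the single-frame models and the video pipeline are kept resident on "cuda", and GPU access is gated by ZeroGPU's @spaces.GPU decorator on the Gradio handlers (interrogator_process, process, and process_video, the latter with duration=360); interrogator_process now also returns its caption twice so it can fill both prompt and i2v_input_text. A minimal sketch of the resulting pattern, assuming the spaces package available on a Hugging Face ZeroGPU Space (helper name shortened; illustrative only, not the full app):

import spaces                      # ZeroGPU decorator package
import torch
from transformers import CLIPTextModel, CLIPTokenizer

model_name = 'lllyasviel/paints_undo_single_frame'
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
# Loaded once at startup and pinned to CUDA instead of being swapped in and out
# by memory_management.
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16).to("cuda")

@spaces.GPU()  # a GPU is attached only while this Gradio handler runs
def encode_prompt(txt: str):  # hypothetical shortened name for this sketch
    ids = tokenizer(txt, padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt").input_ids.to(device="cuda")
    return text_encoder(ids, attention_mask=None).last_hidden_state

# Longer-running handlers can request a bigger time budget in seconds,
# e.g. @spaces.GPU(duration=360) as process_video does in the diff below.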

Files changed (1)
  1. app.py +28 -37
app.py CHANGED
@@ -12,9 +12,8 @@ import gradio as gr
 import numpy as np
 import torch
 import wd14tagger
-import memory_management
 import uuid
-import spaces
+
 from PIL import Image
 from diffusers_helper.code_cond import unet_add_coded_conds
 from diffusers_helper.cat_cond import unet_add_concat_conds
@@ -24,8 +23,11 @@ from diffusers.models.attention_processor import AttnProcessor2_0
 from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
 from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
+import spaces
+
+# Disable gradients globally
 torch.set_grad_enabled(False)
-@spaces.GPU()
+
 class ModifiedUNet(UNet2DConditionModel):
     @classmethod
     def from_config(cls, *args, **kwargs):
@@ -37,9 +39,9 @@ class ModifiedUNet(UNet2DConditionModel):
 
 model_name = 'lllyasviel/paints_undo_single_frame'
 tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
-text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16)
-vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16)  # bfloat16 vae
-unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16)
+text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16).to("cuda")
+vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16).to("cuda")  # bfloat16 vae
+unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16).to("cuda")
 
 unet.set_attn_processor(AttnProcessor2_0())
 vae.set_attn_processor(AttnProcessor2_0())
@@ -47,12 +49,7 @@ vae.set_attn_processor(AttnProcessor2_0())
 video_pipe = LatentVideoDiffusionPipeline.from_pretrained(
     'lllyasviel/paints_undo_multi_frame',
     fp16=True
-)
-
-memory_management.unload_all_models([
-    video_pipe.unet, video_pipe.vae, video_pipe.text_encoder, video_pipe.image_projection, video_pipe.image_encoder,
-    unet, vae, text_encoder
-])
+).to("cuda")
 
 k_sampler = KDiffusionSampler(
     unet=unet,
@@ -73,18 +70,17 @@ def find_best_bucket(h, w, options):
     best_bucket = (bucket_h, bucket_w)
     return best_bucket
 
-@spaces.GPU()
+
 def encode_cropped_prompt_77tokens(txt: str):
-    memory_management.load_models_to_gpu(text_encoder)
     cond_ids = tokenizer(txt,
                          padding="max_length",
                          max_length=tokenizer.model_max_length,
                          truncation=True,
-                         return_tensors="pt").input_ids.to(device=text_encoder.device)
+                         return_tensors="pt").input_ids.to(device="cuda")
     text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
     return text_cond
 
-@spaces.GPU()
+
 def pytorch2numpy(imgs):
     results = []
     for x in imgs:
@@ -94,7 +90,7 @@ def pytorch2numpy(imgs):
         results.append(y)
     return results
 
-@spaces.GPU()
+
 def numpy2pytorch(imgs):
     h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
     h = h.movedim(-1, 1)
@@ -106,28 +102,28 @@ def resize_without_crop(image, target_width, target_height):
     resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
     return np.array(resized_image)
 
+
 @spaces.GPU()
 def interrogator_process(x):
-    return wd14tagger.default_interrogator(x)
+    image_description = wd14tagger.default_interrogator(x)
+    return image_description, image_description
+
 
 @spaces.GPU()
 def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg,
             progress=gr.Progress()):
-    rng = torch.Generator(device=memory_management.gpu).manual_seed(int(seed))
+    rng = torch.Generator(device="cuda").manual_seed(int(seed))
 
-    memory_management.load_models_to_gpu(vae)
     fg = resize_and_center_crop(input_fg, image_width, image_height)
-    concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
+    concat_conds = numpy2pytorch([fg]).clone().detach().to(device="cuda", dtype=vae.dtype)
     concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
 
-    memory_management.load_models_to_gpu(text_encoder)
     conds = encode_cropped_prompt_77tokens(prompt)
     unconds = encode_cropped_prompt_77tokens(n_prompt)
 
-    memory_management.load_models_to_gpu(unet)
-    fs = torch.tensor(input_undo_steps).to(device=unet.device, dtype=torch.long)
+    fs = torch.tensor(input_undo_steps).to(device="cuda", dtype=torch.long)
     initial_latents = torch.zeros_like(concat_conds)
-    concat_conds = concat_conds.to(device=unet.device, dtype=unet.dtype)
+    concat_conds = concat_conds.to(device="cuda", dtype=unet.dtype)
     latents = k_sampler(
         initial_latent=initial_latents,
         strength=1.0,
@@ -142,14 +138,13 @@ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed,
         progress_tqdm=functools.partial(progress.tqdm, desc='Generating Key Frames')
     ).to(vae.dtype) / vae.config.scaling_factor
 
-    memory_management.load_models_to_gpu(vae)
     pixels = vae.decode(latents).sample
    pixels = pytorch2numpy(pixels)
     pixels = [fg] + pixels + [np.zeros_like(fg) + 255]
 
     return pixels
 
-@spaces.GPU()
+
 def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=7.5, fs=3, progress_tqdm=None):
     random.seed(seed)
     np.random.seed(seed)
@@ -168,25 +163,21 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
     input_frames = numpy2pytorch([image_1, image_2])
     input_frames = input_frames.unsqueeze(0).movedim(1, 2)
 
-    memory_management.load_models_to_gpu(video_pipe.text_encoder)
     positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
     negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")
 
-    memory_management.load_models_to_gpu([video_pipe.image_projection, video_pipe.image_encoder])
-    input_frames = input_frames.to(device=video_pipe.image_encoder.device, dtype=video_pipe.image_encoder.dtype)
+    input_frames = input_frames.to(device="cuda", dtype=video_pipe.image_encoder.dtype)
     positive_image_cond = video_pipe.encode_clip_vision(input_frames)
     positive_image_cond = video_pipe.image_projection(positive_image_cond)
     negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
     negative_image_cond = video_pipe.image_projection(negative_image_cond)
 
-    memory_management.load_models_to_gpu([video_pipe.vae])
-    input_frames = input_frames.to(device=video_pipe.vae.device, dtype=video_pipe.vae.dtype)
+    input_frames = input_frames.to(device="cuda", dtype=video_pipe.vae.dtype)
     input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
     first_frame = input_frame_latents[:, :, 0]
     last_frame = input_frame_latents[:, :, 1]
     concat_cond = torch.stack([first_frame] + [torch.zeros_like(first_frame)] * (frames - 2) + [last_frame], dim=2)
 
-    memory_management.load_models_to_gpu([video_pipe.unet])
     latents = video_pipe(
         batch_size=1,
         steps=int(steps),
@@ -200,11 +191,11 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
         progress_tqdm=progress_tqdm
     )
 
-    memory_management.load_models_to_gpu([video_pipe.vae])
     video = video_pipe.decode_latents(latents, vae_hidden_states)
     return video, image_1, image_2
 
-@spaces.GPU
+
+@spaces.GPU(duration=360)
 def process_video(keyframes, prompt, steps, cfg, fps, seed, progress=gr.Progress()):
     result_frames = []
     cropped_images = []
@@ -282,7 +273,7 @@ with block:
     prompt_gen_button.click(
         fn=interrogator_process,
        inputs=[input_fg],
-        outputs=[prompt]
+        outputs=[prompt, i2v_input_text]
     ).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=False)],
            outputs=[prompt_gen_button, key_gen_button, i2v_end_btn])
 
@@ -311,4 +302,4 @@ with block:
         examples_per_page=1024
     )
 
-block.queue().launch(server_name='0.0.0.0')
+block.queue().launch(server_name='0.0.0.0')