Update app.py
app.py CHANGED
@@ -12,9 +12,8 @@ import gradio as gr
 import numpy as np
 import torch
 import wd14tagger
-import memory_management
 import uuid
-
+
 from PIL import Image
 from diffusers_helper.code_cond import unet_add_coded_conds
 from diffusers_helper.cat_cond import unet_add_concat_conds
@@ -24,8 +23,11 @@ from diffusers.models.attention_processor import AttnProcessor2_0
 from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
 from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
+import spaces
+
+# Disable gradients globally
 torch.set_grad_enabled(False)
-
+
 class ModifiedUNet(UNet2DConditionModel):
     @classmethod
     def from_config(cls, *args, **kwargs):
@@ -37,9 +39,9 @@ class ModifiedUNet(UNet2DConditionModel):
 
 model_name = 'lllyasviel/paints_undo_single_frame'
 tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
-text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16)
-vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16)  # bfloat16 vae
-unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16)
+text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16).to("cuda")
+vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16).to("cuda")  # bfloat16 vae
+unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16).to("cuda")
 
 unet.set_attn_processor(AttnProcessor2_0())
 vae.set_attn_processor(AttnProcessor2_0())
@@ -47,12 +49,7 @@ vae.set_attn_processor(AttnProcessor2_0())
 video_pipe = LatentVideoDiffusionPipeline.from_pretrained(
     'lllyasviel/paints_undo_multi_frame',
     fp16=True
-)
-
-memory_management.unload_all_models([
-    video_pipe.unet, video_pipe.vae, video_pipe.text_encoder, video_pipe.image_projection, video_pipe.image_encoder,
-    unet, vae, text_encoder
-])
+).to("cuda")
 
 k_sampler = KDiffusionSampler(
     unet=unet,
@@ -73,18 +70,17 @@ def find_best_bucket(h, w, options):
             best_bucket = (bucket_h, bucket_w)
     return best_bucket
 
-
+
 def encode_cropped_prompt_77tokens(txt: str):
-    memory_management.load_models_to_gpu(text_encoder)
     cond_ids = tokenizer(txt,
                          padding="max_length",
                          max_length=tokenizer.model_max_length,
                          truncation=True,
-                         return_tensors="pt").input_ids.to(device=text_encoder.device)
+                         return_tensors="pt").input_ids.to(device="cuda")
     text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
     return text_cond
 
-
+
 def pytorch2numpy(imgs):
     results = []
     for x in imgs:
@@ -94,7 +90,7 @@ def pytorch2numpy(imgs):
         results.append(y)
     return results
 
-
+
 def numpy2pytorch(imgs):
     h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
     h = h.movedim(-1, 1)
@@ -106,28 +102,28 @@ def resize_without_crop(image, target_width, target_height):
     resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
     return np.array(resized_image)
 
+
 @spaces.GPU()
 def interrogator_process(x):
-    return wd14tagger.default_interrogator(x)
+    image_description = wd14tagger.default_interrogator(x)
+    return image_description, image_description
+
 
 @spaces.GPU()
 def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg,
             progress=gr.Progress()):
-    rng = torch.Generator(device=memory_management.gpu).manual_seed(int(seed))
+    rng = torch.Generator(device="cuda").manual_seed(int(seed))
 
-    memory_management.load_models_to_gpu(vae)
     fg = resize_and_center_crop(input_fg, image_width, image_height)
-    concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
+    concat_conds = numpy2pytorch([fg]).clone().detach().to(device="cuda", dtype=vae.dtype)
     concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
 
-    memory_management.load_models_to_gpu(text_encoder)
     conds = encode_cropped_prompt_77tokens(prompt)
     unconds = encode_cropped_prompt_77tokens(n_prompt)
 
-    memory_management.load_models_to_gpu(unet)
-    fs = torch.tensor(input_undo_steps).to(device=unet.device, dtype=torch.long)
+    fs = torch.tensor(input_undo_steps).to(device="cuda", dtype=torch.long)
     initial_latents = torch.zeros_like(concat_conds)
-    concat_conds = concat_conds.to(device=unet.device, dtype=unet.dtype)
+    concat_conds = concat_conds.to(device="cuda", dtype=unet.dtype)
     latents = k_sampler(
         initial_latent=initial_latents,
         strength=1.0,
@@ -142,14 +138,13 @@ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed,
         progress_tqdm=functools.partial(progress.tqdm, desc='Generating Key Frames')
     ).to(vae.dtype) / vae.config.scaling_factor
 
-    memory_management.load_models_to_gpu(vae)
     pixels = vae.decode(latents).sample
     pixels = pytorch2numpy(pixels)
     pixels = [fg] + pixels + [np.zeros_like(fg) + 255]
 
     return pixels
 
-
+
 def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=7.5, fs=3, progress_tqdm=None):
     random.seed(seed)
     np.random.seed(seed)
@@ -168,25 +163,21 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
     input_frames = numpy2pytorch([image_1, image_2])
     input_frames = input_frames.unsqueeze(0).movedim(1, 2)
 
-    memory_management.load_models_to_gpu(video_pipe.text_encoder)
     positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
     negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")
 
-    memory_management.load_models_to_gpu([video_pipe.image_projection, video_pipe.image_encoder])
-    input_frames = input_frames.to(device=video_pipe.image_encoder.device, dtype=video_pipe.image_encoder.dtype)
+    input_frames = input_frames.to(device="cuda", dtype=video_pipe.image_encoder.dtype)
     positive_image_cond = video_pipe.encode_clip_vision(input_frames)
     positive_image_cond = video_pipe.image_projection(positive_image_cond)
    negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
     negative_image_cond = video_pipe.image_projection(negative_image_cond)
 
-    memory_management.load_models_to_gpu([video_pipe.vae])
-    input_frames = input_frames.to(device=video_pipe.vae.device, dtype=video_pipe.vae.dtype)
+    input_frames = input_frames.to(device="cuda", dtype=video_pipe.vae.dtype)
     input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
     first_frame = input_frame_latents[:, :, 0]
     last_frame = input_frame_latents[:, :, 1]
     concat_cond = torch.stack([first_frame] + [torch.zeros_like(first_frame)] * (frames - 2) + [last_frame], dim=2)
 
-    memory_management.load_models_to_gpu([video_pipe.unet])
     latents = video_pipe(
         batch_size=1,
         steps=int(steps),
@@ -200,11 +191,11 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
         progress_tqdm=progress_tqdm
     )
 
-    memory_management.load_models_to_gpu([video_pipe.vae])
     video = video_pipe.decode_latents(latents, vae_hidden_states)
     return video, image_1, image_2
 
-
+
+@spaces.GPU(duration=360)
 def process_video(keyframes, prompt, steps, cfg, fps, seed, progress=gr.Progress()):
     result_frames = []
     cropped_images = []
@@ -282,7 +273,7 @@ with block:
     prompt_gen_button.click(
         fn=interrogator_process,
         inputs=[input_fg],
-        outputs=[prompt]
+        outputs=[prompt, i2v_input_text]
     ).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=False)],
           outputs=[prompt_gen_button, key_gen_button, i2v_end_btn])
 
@@ -311,4 +302,4 @@ with block:
         examples_per_page=1024
     )
 
-block.queue().launch(server_name='0.0.0.0')
+block.queue().launch(server_name='0.0.0.0')
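The pattern this commit moves to is the standard Hugging Face ZeroGPU setup visible in the diff: the `memory_management` load/unload calls are removed, model weights are placed on "cuda" at import time, and each GPU-using entry point is wrapped in `@spaces.GPU(...)` (with a longer `duration` for the video function). A minimal sketch of that pattern, assuming a ZeroGPU Space with the `spaces` package installed; the `model` and `run` names here are illustrative stand-ins, not objects from app.py:

import spaces
import torch

torch.set_grad_enabled(False)

# Illustrative stand-in for the real pipelines: weights are moved to "cuda"
# at import time; on ZeroGPU the device is only attached while a
# @spaces.GPU-decorated function is actually running.
model = torch.nn.Linear(8, 8).half().to("cuda")

@spaces.GPU()  # default duration; long jobs can request e.g. @spaces.GPU(duration=360)
def run(values):
    # Inputs are moved to the GPU inside the decorated call, mirroring the
    # .to(device="cuda", ...) calls added throughout this commit.
    t = torch.as_tensor(values, dtype=torch.float16, device="cuda")
    return model(t).float().cpu().numpy()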