Update app.py
Browse files
app.py
CHANGED
@@ -42,18 +42,16 @@ prepare_latents_original = pipe.prepare_latents
|
|
42 |
|
43 |
# pipe.prepare_latents = prepare_latents_loop
|
44 |
|
45 |
-
# with the shift it will become step=0 0,1,2,3 -> step=1 1,2,3,0 -> step=2 2,3,0,1 -> step=3 3,0,1,2 -> step=4 0,1,2,3
|
46 |
-
# so we only shift (N_FRAME-1)//8+1 times
|
47 |
-
|
48 |
def modify_latents_callback(pipeline, step, timestep, callback_kwargs):
|
49 |
print("Rolling latents on step", step)
|
50 |
latents = callback_kwargs.get("latents")
|
51 |
unpacked_latents = pipeline._unpack_latents(latents, (N_FRAME-1)//8+1, HEIGHT//32, WIDTH//32, 1, 1)
|
|
|
52 |
modified_latents = torch.roll(unpacked_latents, shifts=1, dims=2)
|
53 |
modified_latents = pipeline._pack_latents(modified_latents)
|
54 |
return {"latents": modified_latents}
|
55 |
|
56 |
-
@spaces.GPU(duration=
|
57 |
def generate_gif(prompt, use_fixed_seed):
|
58 |
seed = 0 if use_fixed_seed else torch.seed()
|
59 |
generator = torch.Generator(device="cuda").manual_seed(seed)
|
|
|
42 |
|
43 |
# pipe.prepare_latents = prepare_latents_loop
|
44 |
|
|
|
|
|
|
|
45 |
def modify_latents_callback(pipeline, step, timestep, callback_kwargs):
|
46 |
print("Rolling latents on step", step)
|
47 |
latents = callback_kwargs.get("latents")
|
48 |
unpacked_latents = pipeline._unpack_latents(latents, (N_FRAME-1)//8+1, HEIGHT//32, WIDTH//32, 1, 1)
|
49 |
+
# the frame order after each denoising step will be 0,1,2 -> 2,0,1 -> 1,2,0 -> 0,1,2 ...
|
50 |
modified_latents = torch.roll(unpacked_latents, shifts=1, dims=2)
|
51 |
modified_latents = pipeline._pack_latents(modified_latents)
|
52 |
return {"latents": modified_latents}
|
53 |
|
54 |
+
@spaces.GPU(duration=140)
|
55 |
def generate_gif(prompt, use_fixed_seed):
|
56 |
seed = 0 if use_fixed_seed else torch.seed()
|
57 |
generator = torch.Generator(device="cuda").manual_seed(seed)
|