Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -171,59 +171,37 @@ ASPECT_RATIOS = {
|
|
| 171 |
|
| 172 |
def get_vae_cache_for_aspect_ratio(aspect_ratio, device, dtype):
|
| 173 |
"""
|
| 174 |
-
|
| 175 |
-
|
| 176 |
"""
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
cache.append(torch.zeros(
|
| 206 |
-
1,
|
| 207 |
-
1,
|
| 208 |
-
256,
|
| 209 |
-
latent_h // 2,
|
| 210 |
-
latent_w // 2,
|
| 211 |
-
device=device,
|
| 212 |
-
dtype=dtype
|
| 213 |
-
))
|
| 214 |
-
# 第四级特征,channels=128,不下采样
|
| 215 |
-
cache.append(torch.zeros(
|
| 216 |
-
1,
|
| 217 |
-
1,
|
| 218 |
-
128,
|
| 219 |
-
latent_h,
|
| 220 |
-
latent_w,
|
| 221 |
-
device=device,
|
| 222 |
-
dtype=dtype
|
| 223 |
-
))
|
| 224 |
-
|
| 225 |
-
return cache
|
| 226 |
-
|
| 227 |
|
| 228 |
def frames_to_ts_file(frames, filepath, fps = 15):
|
| 229 |
"""
|
|
@@ -416,14 +394,8 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, aspect_ratio="16
|
|
| 416 |
|
| 417 |
vae_cache, latents_cache = None, None
|
| 418 |
if not APP_STATE["current_use_taehv"] and not args.trt:
|
| 419 |
-
#
|
| 420 |
-
|
| 421 |
-
if aspect_ratio == "16:9":
|
| 422 |
-
# Use default cache for 16:9
|
| 423 |
-
vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
|
| 424 |
-
else:
|
| 425 |
-
# Create custom cache for 9:16
|
| 426 |
-
vae_cache = get_vae_cache_for_aspect_ratio(aspect_ratio, gpu, torch.float16)
|
| 427 |
|
| 428 |
num_blocks = 7
|
| 429 |
current_start_frame = 0
|
|
|
|
| 171 |
|
| 172 |
def get_vae_cache_for_aspect_ratio(aspect_ratio, device, dtype):
|
| 173 |
"""
|
| 174 |
+
Create VAE cache with appropriate dimensions for the given aspect ratio.
|
| 175 |
+
Based on the structure of ZERO_VAE_CACHE but adjusted for different aspect ratios.
|
| 176 |
"""
|
| 177 |
+
# First, let's check the structure of ZERO_VAE_CACHE to understand the format
|
| 178 |
+
print(f"Creating VAE cache for {aspect_ratio}")
|
| 179 |
+
|
| 180 |
+
# For 9:16, we need to swap the height and width dimensions from the 16:9 default
|
| 181 |
+
if aspect_ratio == "9:16":
|
| 182 |
+
# The cache structure from ZERO_VAE_CACHE appears to be feature maps at different scales
|
| 183 |
+
# We need to maintain the same structure but swap H and W dimensions
|
| 184 |
+
cache = []
|
| 185 |
+
for i, tensor in enumerate(ZERO_VAE_CACHE):
|
| 186 |
+
# Get the original shape
|
| 187 |
+
original_shape = list(tensor.shape)
|
| 188 |
+
print(f"Original cache tensor {i} shape: {original_shape}")
|
| 189 |
+
|
| 190 |
+
# For 9:16, we swap the last two dimensions (H and W)
|
| 191 |
+
if len(original_shape) == 5: # (B, C, T, H, W)
|
| 192 |
+
new_shape = original_shape.copy()
|
| 193 |
+
new_shape[-2], new_shape[-1] = original_shape[-1], original_shape[-2] # Swap H and W
|
| 194 |
+
new_tensor = torch.zeros(new_shape, device=device, dtype=dtype)
|
| 195 |
+
cache.append(new_tensor)
|
| 196 |
+
print(f"New cache tensor {i} shape: {new_shape}")
|
| 197 |
+
else:
|
| 198 |
+
# If not 5D, just copy as is
|
| 199 |
+
cache.append(tensor.to(device=device, dtype=dtype))
|
| 200 |
+
|
| 201 |
+
return cache
|
| 202 |
+
else:
|
| 203 |
+
# For 16:9, use the default cache
|
| 204 |
+
return [c.to(device=device, dtype=dtype) for c in ZERO_VAE_CACHE]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
def frames_to_ts_file(frames, filepath, fps = 15):
|
| 207 |
"""
|
|
|
|
| 394 |
|
| 395 |
vae_cache, latents_cache = None, None
|
| 396 |
if not APP_STATE["current_use_taehv"] and not args.trt:
|
| 397 |
+
# Create VAE cache appropriate for the aspect ratio
|
| 398 |
+
vae_cache = get_vae_cache_for_aspect_ratio(aspect_ratio, gpu, torch.float16)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
|
| 400 |
num_blocks = 7
|
| 401 |
current_start_frame = 0
|