Spaces:
Runtime error
Runtime error
daniel shalem
committed on
Commit
·
1940326
1
Parent(s):
645fba0
Feature: Add mixed precision support and direct bfloat16 support.
Browse files
xora/examples/image_to_video.py
CHANGED
|
@@ -136,6 +136,12 @@ def main():
|
|
| 136 |
"--frame_rate", type=int, default=25, help="Frame rate for the output video"
|
| 137 |
)
|
| 138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
# Prompts
|
| 140 |
parser.add_argument(
|
| 141 |
"--prompt",
|
|
@@ -224,6 +230,7 @@ def main():
|
|
| 224 |
is_video=True,
|
| 225 |
vae_per_channel_normalize=True,
|
| 226 |
conditioning_method=ConditioningMethod.FIRST_FRAME,
|
|
|
|
| 227 |
).images
|
| 228 |
|
| 229 |
# Save output video
|
|
|
|
| 136 |
"--frame_rate", type=int, default=25, help="Frame rate for the output video"
|
| 137 |
)
|
| 138 |
|
| 139 |
+
parser.add_argument(
|
| 140 |
+
"--mixed_precision",
|
| 141 |
+
action="store_true",
|
| 142 |
+
help="Mixed precision in float32 and bfloat16",
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
# Prompts
|
| 146 |
parser.add_argument(
|
| 147 |
"--prompt",
|
|
|
|
| 230 |
is_video=True,
|
| 231 |
vae_per_channel_normalize=True,
|
| 232 |
conditioning_method=ConditioningMethod.FIRST_FRAME,
|
| 233 |
+
mixed_precision=args.mixed_precision,
|
| 234 |
).images
|
| 235 |
|
| 236 |
# Save output video
|
xora/models/transformers/transformer3d.py
CHANGED
|
@@ -305,7 +305,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
| 305 |
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
| 306 |
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
| 307 |
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
| 308 |
-
return cos_freq, sin_freq
|
| 309 |
|
| 310 |
def forward(
|
| 311 |
self,
|
|
|
|
| 305 |
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
| 306 |
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
| 307 |
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
| 308 |
+
return cos_freq.to(dtype), sin_freq.to(dtype)
|
| 309 |
|
| 310 |
def forward(
|
| 311 |
self,
|
xora/pipelines/pipeline_xora_video.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
|
|
| 9 |
|
| 10 |
import torch
|
| 11 |
import torch.nn.functional as F
|
|
|
|
| 12 |
from diffusers.image_processor import VaeImageProcessor
|
| 13 |
from diffusers.models import AutoencoderKL
|
| 14 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
|
@@ -758,6 +759,7 @@ class XoraVideoPipeline(DiffusionPipeline):
|
|
| 758 |
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 759 |
clean_caption: bool = True,
|
| 760 |
media_items: Optional[torch.FloatTensor] = None,
|
|
|
|
| 761 |
**kwargs,
|
| 762 |
) -> Union[ImagePipelineOutput, Tuple]:
|
| 763 |
"""
|
|
@@ -1006,16 +1008,22 @@ class XoraVideoPipeline(DiffusionPipeline):
|
|
| 1006 |
|
| 1007 |
if conditioning_mask is not None:
|
| 1008 |
current_timestep = current_timestep * (1 - conditioning_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1009 |
|
| 1010 |
# predict noise model_output
|
| 1011 |
-
|
| 1012 |
-
|
| 1013 |
-
|
| 1014 |
-
|
| 1015 |
-
|
| 1016 |
-
|
| 1017 |
-
|
| 1018 |
-
|
|
|
|
| 1019 |
|
| 1020 |
# perform guidance
|
| 1021 |
if do_classifier_free_guidance:
|
|
|
|
| 9 |
|
| 10 |
import torch
|
| 11 |
import torch.nn.functional as F
|
| 12 |
+
from contextlib import nullcontext
|
| 13 |
from diffusers.image_processor import VaeImageProcessor
|
| 14 |
from diffusers.models import AutoencoderKL
|
| 15 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
|
|
|
| 759 |
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 760 |
clean_caption: bool = True,
|
| 761 |
media_items: Optional[torch.FloatTensor] = None,
|
| 762 |
+
mixed_precision: bool = False,
|
| 763 |
**kwargs,
|
| 764 |
) -> Union[ImagePipelineOutput, Tuple]:
|
| 765 |
"""
|
|
|
|
| 1008 |
|
| 1009 |
if conditioning_mask is not None:
|
| 1010 |
current_timestep = current_timestep * (1 - conditioning_mask)
|
| 1011 |
+
# Choose the appropriate context manager based on `mixed_precision`
|
| 1012 |
+
if mixed_precision:
|
| 1013 |
+
context_manager = torch.autocast("cuda", dtype=torch.bfloat16)
|
| 1014 |
+
else:
|
| 1015 |
+
context_manager = nullcontext() # Dummy context manager
|
| 1016 |
|
| 1017 |
# predict noise model_output
|
| 1018 |
+
with context_manager:
|
| 1019 |
+
noise_pred = self.transformer(
|
| 1020 |
+
latent_model_input.to(self.transformer.dtype),
|
| 1021 |
+
indices_grid,
|
| 1022 |
+
encoder_hidden_states=prompt_embeds.to(self.transformer.dtype),
|
| 1023 |
+
encoder_attention_mask=prompt_attention_mask,
|
| 1024 |
+
timestep=current_timestep,
|
| 1025 |
+
return_dict=False,
|
| 1026 |
+
)[0]
|
| 1027 |
|
| 1028 |
# perform guidance
|
| 1029 |
if do_classifier_free_guidance:
|