alexnasa committed
Commit bb65ef0 · verified · 1 Parent(s): 3e48e28

Upload 42 files

Files changed (43)
  1. .gitattributes +6 -0
  2. LICENSE.txt +201 -0
  3. OmniAvatar/base.py +127 -0
  4. OmniAvatar/configs/__init__.py +0 -0
  5. OmniAvatar/configs/model_config.py +17 -0
  6. OmniAvatar/distributed/__init__.py +0 -0
  7. OmniAvatar/distributed/fsdp.py +43 -0
  8. OmniAvatar/distributed/xdit_context_parallel.py +134 -0
  9. OmniAvatar/models/audio_pack.py +40 -0
  10. OmniAvatar/models/model_manager.py +432 -0
  11. OmniAvatar/models/vsa_util.py +232 -0
  12. OmniAvatar/models/wan_video_dit.py +607 -0
  13. OmniAvatar/models/wan_video_text_encoder.py +269 -0
  14. OmniAvatar/models/wan_video_vae.py +807 -0
  15. OmniAvatar/models/wav2vec.py +208 -0
  16. OmniAvatar/prompters/__init__.py +1 -0
  17. OmniAvatar/prompters/base_prompter.py +70 -0
  18. OmniAvatar/prompters/wan_prompter.py +109 -0
  19. OmniAvatar/schedulers/flow_match.py +79 -0
  20. OmniAvatar/utils/args_config.py +123 -0
  21. OmniAvatar/utils/audio_preprocess.py +21 -0
  22. OmniAvatar/utils/io_utils.py +256 -0
  23. OmniAvatar/vram_management/__init__.py +1 -0
  24. OmniAvatar/vram_management/layers.py +95 -0
  25. OmniAvatar/wan_video.py +344 -0
  26. README.md +13 -12
  27. app.py +942 -0
  28. args_config.yaml +71 -0
  29. assets/logo-omniavatar.png +0 -0
  30. assets/material/pipeline.png +3 -0
  31. assets/material/teaser.png +3 -0
  32. configs/inference.yaml +37 -0
  33. configs/inference_1.3B.yaml +37 -0
  34. examples/audios/fox.wav +3 -0
  35. examples/audios/lion.wav +3 -0
  36. examples/audios/ocean.wav +3 -0
  37. examples/audios/script.wav +3 -0
  38. examples/images/female-002.png +0 -0
  39. examples/images/female-003.png +3 -0
  40. examples/images/female-009.png +0 -0
  41. examples/images/male-001.png +3 -0
  42. requirements.txt +18 -0
  43. scripts/inference.py +383 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/audios/fox.wav filter=lfs diff=lfs merge=lfs -text
37
+ examples/audios/lion.wav filter=lfs diff=lfs merge=lfs -text
38
+ examples/audios/ocean.wav filter=lfs diff=lfs merge=lfs -text
39
+ examples/audios/script.wav filter=lfs diff=lfs merge=lfs -text
40
+ examples/images/female-003.png filter=lfs diff=lfs merge=lfs -text
41
+ examples/images/male-001.png filter=lfs diff=lfs merge=lfs -text
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
OmniAvatar/base.py ADDED
@@ -0,0 +1,127 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from torchvision.transforms import GaussianBlur
5
+
6
+
7
+
8
+ class BasePipeline(torch.nn.Module):
9
+
10
+ def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
11
+ super().__init__()
12
+ self.device = device
13
+ self.torch_dtype = torch_dtype
14
+ self.height_division_factor = height_division_factor
15
+ self.width_division_factor = width_division_factor
16
+ self.cpu_offload = False
17
+ self.model_names = []
18
+
19
+
20
+ def check_resize_height_width(self, height, width):
21
+ if height % self.height_division_factor != 0:
22
+ height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
23
+ print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
24
+ if width % self.width_division_factor != 0:
25
+ width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
26
+ print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
27
+ return height, width
28
+
29
+
30
+ def preprocess_image(self, image):
31
+ image = torch.Tensor(np.array(image, dtype=np.float16) * (2.0 / 255) - 1.0).permute(2, 0, 1).unsqueeze(0)
32
+ return image
33
+
34
+
35
+ def preprocess_images(self, images):
36
+ return [self.preprocess_image(image) for image in images]
37
+
38
+
39
+ def vae_output_to_image(self, vae_output):
40
+ image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
41
+ image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
42
+ return image
43
+
44
+
45
+ def vae_output_to_video(self, vae_output):
46
+ video = vae_output.cpu().permute(1, 2, 0).numpy()
47
+ video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
48
+ return video
49
+
50
+
51
+ def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
52
+ if len(latents) > 0:
53
+ blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
54
+ height, width = value.shape[-2:]
55
+ weight = torch.ones_like(value)
56
+ for latent, mask, scale in zip(latents, masks, scales):
57
+ mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
58
+ mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
59
+ mask = blur(mask)
60
+ value += latent * mask * scale
61
+ weight += mask * scale
62
+ value /= weight
63
+ return value
64
+
65
+
66
+ def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
67
+ if special_kwargs is None:
68
+ noise_pred_global = inference_callback(prompt_emb_global)
69
+ else:
70
+ noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
71
+ if special_local_kwargs_list is None:
72
+ noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
73
+ else:
74
+ noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
75
+ noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
76
+ return noise_pred
77
+
78
+
79
+ def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
80
+ local_prompts = local_prompts or []
81
+ masks = masks or []
82
+ mask_scales = mask_scales or []
83
+ extended_prompt_dict = self.prompter.extend_prompt(prompt)
84
+ prompt = extended_prompt_dict.get("prompt", prompt)
85
+ local_prompts += extended_prompt_dict.get("prompts", [])
86
+ masks += extended_prompt_dict.get("masks", [])
87
+ mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
88
+ return prompt, local_prompts, masks, mask_scales
89
+
90
+
91
+ def enable_cpu_offload(self):
92
+ self.cpu_offload = True
93
+
94
+
95
+ def load_models_to_device(self, loadmodel_names=[]):
96
+ # only load models to device if cpu_offload is enabled
97
+ if not self.cpu_offload:
98
+ return
99
+ # offload the unneeded models to cpu
100
+ for model_name in self.model_names:
101
+ if model_name not in loadmodel_names:
102
+ model = getattr(self, model_name)
103
+ if model is not None:
104
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
105
+ for module in model.modules():
106
+ if hasattr(module, "offload"):
107
+ module.offload()
108
+ else:
109
+ model.cpu()
110
+ # load the needed models to device
111
+ for model_name in loadmodel_names:
112
+ model = getattr(self, model_name)
113
+ if model is not None:
114
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
115
+ for module in model.modules():
116
+ if hasattr(module, "onload"):
117
+ module.onload()
118
+ else:
119
+ model.to(self.device)
120
+ # free the cached CUDA memory
121
+ torch.cuda.empty_cache()
122
+
123
+
124
+ def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
125
+ generator = None if seed is None else torch.Generator(device).manual_seed(seed)
126
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
127
+ return noise
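
As a quick illustration of the rounding in check_resize_height_width above (a minimal sketch; it assumes the OmniAvatar package added in this commit is importable):

    # minimal sketch: dimensions are rounded up to the nearest multiple of the division factor
    from OmniAvatar.base import BasePipeline

    pipe = BasePipeline(device="cpu")                    # the device string is only stored here
    print(pipe.check_resize_height_width(720, 1000))     # -> (768, 1024) with the default factor of 64
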
OmniAvatar/configs/__init__.py ADDED
File without changes
OmniAvatar/configs/model_config.py ADDED
@@ -0,0 +1,17 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+ from ..models.wan_video_dit import WanModel
3
+ from ..models.wan_video_text_encoder import WanTextEncoder
4
+ from ..models.wan_video_vae import WanVideoVAE
5
+
6
+
7
+ model_loader_configs = [
8
+ # These configs are provided for detecting model type automatically.
9
+ # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
10
+ (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
11
+ (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
12
+ (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
13
+ (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
14
+ (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
15
+ (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
16
+ (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
17
+ ]
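
For context, these tuples are consumed by the detectors in OmniAvatar/models/model_manager.py (added later in this commit), which hash a checkpoint's state-dict keys and look the hash up in this table. A rough sketch of that flow, with a placeholder checkpoint path:

    # rough sketch only; the checkpoint path is a placeholder
    from OmniAvatar.configs.model_config import model_loader_configs
    from OmniAvatar.models.model_manager import ModelDetectorFromSingleFile
    from OmniAvatar.utils.io_utils import load_state_dict, hash_state_dict_keys

    detector = ModelDetectorFromSingleFile(model_loader_configs)
    state_dict = load_state_dict("path/to/checkpoint.safetensors")      # placeholder path
    print(hash_state_dict_keys(state_dict, with_shape=True))            # hash matched against the table above
    print(detector.match(state_dict=state_dict))                        # True if the hash is registered
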
OmniAvatar/distributed/__init__.py ADDED
File without changes
OmniAvatar/distributed/fsdp.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ from functools import partial
4
+
5
+ import torch
6
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
8
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
9
+ from torch.distributed.utils import _free_storage
10
+
11
+
12
+ def shard_model(
13
+ model,
14
+ device_id,
15
+ param_dtype=torch.bfloat16,
16
+ reduce_dtype=torch.float32,
17
+ buffer_dtype=torch.float32,
18
+ process_group=None,
19
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
20
+ sync_module_states=True,
21
+ ):
22
+ model = FSDP(
23
+ module=model,
24
+ process_group=process_group,
25
+ sharding_strategy=sharding_strategy,
26
+ auto_wrap_policy=partial(
27
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
28
+ mixed_precision=MixedPrecision(
29
+ param_dtype=param_dtype,
30
+ reduce_dtype=reduce_dtype,
31
+ buffer_dtype=buffer_dtype),
32
+ device_id=device_id,
33
+ sync_module_states=sync_module_states)
34
+ return model
35
+
36
+
37
+ def free_model(model):
38
+ for m in model.modules():
39
+ if isinstance(m, FSDP):
40
+ _free_storage(m._handle.flat_param.data)
41
+ del model
42
+ gc.collect()
43
+ torch.cuda.empty_cache()
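
A hypothetical usage sketch for shard_model: it assumes torch.distributed has already been initialized (e.g. via torchrun) and that the wrapped model exposes its transformer blocks as model.blocks, which is what the lambda_auto_wrap_policy above keys on. ToyDiT below is a stand-in, not the real WanModel.

    # hypothetical sketch; ToyDiT stands in for a real DiT that exposes .blocks
    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from OmniAvatar.distributed.fsdp import shard_model

    class ToyDiT(nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])

    dist.init_process_group("nccl")                          # e.g. launched with torchrun
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)
    sharded = shard_model(ToyDiT(), device_id=local_rank)    # each block becomes its own FSDP unit
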
OmniAvatar/distributed/xdit_context_parallel.py ADDED
@@ -0,0 +1,134 @@
1
+ import torch
2
+ from typing import Optional
3
+ from einops import rearrange
4
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
5
+ get_sequence_parallel_world_size,
6
+ get_sp_group)
7
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
8
+ from yunchang import LongContextAttention
9
+
10
+ def sinusoidal_embedding_1d(dim, position):
11
+ sinusoid = torch.outer(position.type(torch.float64), torch.pow(
12
+ 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
13
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
14
+ return x.to(position.dtype)
15
+
16
+ def pad_freqs(original_tensor, target_len):
17
+ seq_len, s1, s2 = original_tensor.shape
18
+ pad_size = target_len - seq_len
19
+ padding_tensor = torch.ones(
20
+ pad_size,
21
+ s1,
22
+ s2,
23
+ dtype=original_tensor.dtype,
24
+ device=original_tensor.device)
25
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
26
+ return padded_tensor
27
+
28
+ def rope_apply(x, freqs, num_heads):
29
+ x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
30
+ s_per_rank = x.shape[1]
31
+ s_per_rank = get_sp_group().broadcast_object_list([s_per_rank], src=0)[0] # TODO: the size should be divided by sp_size
32
+
33
+ x_out = torch.view_as_complex(x.to(torch.float64).reshape(
34
+ x.shape[0], x.shape[1], x.shape[2], -1, 2))
35
+
36
+ sp_size = get_sequence_parallel_world_size()
37
+ sp_rank = get_sequence_parallel_rank()
38
+ if freqs.shape[0] % sp_size != 0 and freqs.shape[0] // sp_size == s_per_rank:
39
+ s_per_rank = s_per_rank + 1
40
+ freqs = pad_freqs(freqs, s_per_rank * sp_size)
41
+ freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
42
+ freqs_rank = freqs_rank[:x.shape[1]]
43
+ x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
44
+ return x_out.to(x.dtype)
45
+
46
+ def usp_dit_forward(self,
47
+ x: torch.Tensor,
48
+ timestep: torch.Tensor,
49
+ context: torch.Tensor,
50
+ clip_feature: Optional[torch.Tensor] = None,
51
+ y: Optional[torch.Tensor] = None,
52
+ use_gradient_checkpointing: bool = False,
53
+ use_gradient_checkpointing_offload: bool = False,
54
+ **kwargs,
55
+ ):
56
+ t = self.time_embedding(
57
+ sinusoidal_embedding_1d(self.freq_dim, timestep))
58
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
59
+ context = self.text_embedding(context)
60
+
61
+ if self.has_image_input:
62
+ x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
63
+ clip_embdding = self.img_emb(clip_feature)
64
+ context = torch.cat([clip_embdding, context], dim=1)
65
+
66
+ x, (f, h, w) = self.patchify(x)
67
+
68
+ freqs = torch.cat([
69
+ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
70
+ self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
71
+ self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
72
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
73
+
74
+ def create_custom_forward(module):
75
+ def custom_forward(*inputs):
76
+ return module(*inputs)
77
+ return custom_forward
78
+
79
+ # Context Parallel
80
+ x = torch.chunk(
81
+ x, get_sequence_parallel_world_size(),
82
+ dim=1)[get_sequence_parallel_rank()]
83
+
84
+ for block in self.blocks:
85
+ if self.training and use_gradient_checkpointing:
86
+ if use_gradient_checkpointing_offload:
87
+ with torch.autograd.graph.save_on_cpu():
88
+ x = torch.utils.checkpoint.checkpoint(
89
+ create_custom_forward(block),
90
+ x, context, t_mod, freqs,
91
+ use_reentrant=False,
92
+ )
93
+ else:
94
+ x = torch.utils.checkpoint.checkpoint(
95
+ create_custom_forward(block),
96
+ x, context, t_mod, freqs,
97
+ use_reentrant=False,
98
+ )
99
+ else:
100
+ x = block(x, context, t_mod, freqs)
101
+
102
+ x = self.head(x, t)
103
+
104
+ # Context Parallel
105
+ if x.shape[1] * get_sequence_parallel_world_size() < freqs.shape[0]:
106
+ x = torch.cat([x, x[:, -1:]], 1) # TODO: this may cause some bias; the best fix is to use sp_size=2
107
+ x = get_sp_group().all_gather(x, dim=1) # TODO: the size should be divided by sp_size
108
+ x = x[:, :freqs.shape[0]]
109
+
110
+ # unpatchify
111
+ x = self.unpatchify(x, (f, h, w))
112
+ return x
113
+
114
+
115
+ def usp_attn_forward(self, x, freqs):
116
+ q = self.norm_q(self.q(x))
117
+ k = self.norm_k(self.k(x))
118
+ v = self.v(x)
119
+
120
+ q = rope_apply(q, freqs, self.num_heads)
121
+ k = rope_apply(k, freqs, self.num_heads)
122
+ q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
123
+ k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
124
+ v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
125
+
126
+ x = xFuserLongContextAttention()(
127
+ None,
128
+ query=q,
129
+ key=k,
130
+ value=v,
131
+ )
132
+ x = x.flatten(2)
133
+
134
+ return self.o(x)
OmniAvatar/models/audio_pack.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+ from typing import Tuple, Union
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import nn
6
+
7
+
8
+ def make_triple(value: Union[int, Tuple[int, int, int]]) -> Tuple[int, int, int]:
9
+ value = (value,) * 3 if isinstance(value, int) else value
10
+ assert len(value) == 3
11
+ return value
12
+
13
+
14
+ class AudioPack(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_channels: int,
18
+ patch_size: Union[int, Tuple[int, int, int]],
19
+ dim: int,
20
+ layernorm=False,
21
+ ):
22
+ super().__init__()
23
+ t, h, w = make_triple(patch_size)
24
+ self.patch_size = t, h, w
25
+ self.proj = nn.Linear(in_channels * t * h * w, dim)
26
+ if layernorm:
27
+ self.norm_out = nn.LayerNorm(dim)
28
+ else:
29
+ self.norm_out = None
30
+
31
+ def forward(
32
+ self,
33
+ vid: torch.Tensor,
34
+ ) -> torch.Tensor:
35
+ t, h, w = self.patch_size
36
+ vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w)
37
+ vid = self.proj(vid)
38
+ if self.norm_out is not None:
39
+ vid = self.norm_out(vid)
40
+ return vid
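
A worked shape example for AudioPack (illustrative sizes only): a (4, 4, 4) patch over a 5-D feature map collapses each patch into one token before the linear projection.

    # illustrative sizes; shapes follow the rearrange pattern "b c (T t) (H h) (W w) -> b T H W (t h w c)"
    import torch
    from OmniAvatar.models.audio_pack import AudioPack

    pack = AudioPack(in_channels=32, patch_size=(4, 4, 4), dim=512, layernorm=True)
    feats = torch.randn(1, 32, 8, 16, 16)        # (b, c, T*t, H*h, W*w)
    print(pack(feats).shape)                     # torch.Size([1, 2, 4, 4, 512])
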
OmniAvatar/models/model_manager.py ADDED
@@ -0,0 +1,432 @@
1
+ import os, torch, json, importlib
2
+ from typing import List
3
+ import torch.nn as nn
4
+ from ..configs.model_config import model_loader_configs
5
+ from ..utils.io_utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix, smart_load_weights
6
+
7
+ class GeneralLoRAFromPeft:
8
+
9
+ def get_name_dict(self, lora_state_dict):
10
+ lora_name_dict = {}
11
+ for key in lora_state_dict:
12
+ if ".lora_B." not in key:
13
+ continue
14
+ keys = key.split(".")
15
+ if len(keys) > keys.index("lora_B") + 2:
16
+ keys.pop(keys.index("lora_B") + 1)
17
+ keys.pop(keys.index("lora_B"))
18
+ if keys[0] == "diffusion_model":
19
+ keys.pop(0)
20
+ target_name = ".".join(keys)
21
+ lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
22
+ return lora_name_dict
23
+
24
+
25
+ def match(self, model: torch.nn.Module, state_dict_lora):
26
+ lora_name_dict = self.get_name_dict(state_dict_lora)
27
+ model_name_dict = {name: None for name, _ in model.named_parameters()}
28
+ matched_num = sum([i in model_name_dict for i in lora_name_dict])
29
+ if matched_num == len(lora_name_dict):
30
+ return "", ""
31
+ else:
32
+ return None
33
+
34
+
35
+ def fetch_device_and_dtype(self, state_dict):
36
+ device, dtype = None, None
37
+ for name, param in state_dict.items():
38
+ device, dtype = param.device, param.dtype
39
+ break
40
+ computation_device = device
41
+ computation_dtype = dtype
42
+ if computation_device == torch.device("cpu"):
43
+ if torch.cuda.is_available():
44
+ computation_device = torch.device("cuda")
45
+ if computation_dtype == torch.float8_e4m3fn:
46
+ computation_dtype = torch.float32
47
+ return device, dtype, computation_device, computation_dtype
48
+
49
+
50
+ def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
51
+ state_dict_model = model.state_dict()
52
+ device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
53
+ lora_name_dict = self.get_name_dict(state_dict_lora)
54
+ for name in lora_name_dict:
55
+ weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
56
+ weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
57
+ if len(weight_up.shape) == 4:
58
+ weight_up = weight_up.squeeze(3).squeeze(2)
59
+ weight_down = weight_down.squeeze(3).squeeze(2)
60
+ weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
61
+ else:
62
+ weight_lora = alpha * torch.mm(weight_up, weight_down)
63
+ weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
64
+ weight_patched = weight_model + weight_lora
65
+ state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
66
+ print(f" {len(lora_name_dict)} tensors are updated.")
67
+ model.load_state_dict(state_dict_model)
68
+
69
+
70
+ def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device, infer):
71
+ loaded_model_names, loaded_models = [], []
72
+ for model_name, model_class in zip(model_names, model_classes):
73
+ print(f" model_name: {model_name} model_class: {model_class.__name__}")
74
+ state_dict_converter = model_class.state_dict_converter()
75
+ if model_resource == "civitai":
76
+ state_dict_results = state_dict_converter.from_civitai(state_dict)
77
+ elif model_resource == "diffusers":
78
+ state_dict_results = state_dict_converter.from_diffusers(state_dict)
79
+ if isinstance(state_dict_results, tuple):
80
+ model_state_dict, extra_kwargs = state_dict_results
81
+ print(f" This model is initialized with extra kwargs: {extra_kwargs}")
82
+ else:
83
+ model_state_dict, extra_kwargs = state_dict_results, {}
84
+ torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
85
+ with init_weights_on_device():
86
+ model = model_class(**extra_kwargs)
87
+ if hasattr(model, "eval"):
88
+ model = model.eval()
89
+ if not infer: # only initialize weights when training
90
+ model = model.to_empty(device=torch.device("cuda"))
91
+ for name, param in model.named_parameters():
92
+ if param.dim() > 1: # usually only weight matrices, not biases, are initialized
93
+ nn.init.xavier_uniform_(param, gain=0.05)
94
+ else:
95
+ nn.init.zeros_(param)
96
+ else:
97
+ model = model.to_empty(device=device)
98
+ model, _, _ = smart_load_weights(model, model_state_dict)
99
+ # model.load_state_dict(model_state_dict, assign=True, strict=False)
100
+ model = model.to(dtype=torch_dtype, device=device)
101
+ loaded_model_names.append(model_name)
102
+ loaded_models.append(model)
103
+ return loaded_model_names, loaded_models
104
+
105
+
106
+ def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
107
+ loaded_model_names, loaded_models = [], []
108
+ for model_name, model_class in zip(model_names, model_classes):
109
+ if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
110
+ model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
111
+ else:
112
+ model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
113
+ if torch_dtype == torch.float16 and hasattr(model, "half"):
114
+ model = model.half()
115
+ try:
116
+ model = model.to(device=device)
117
+ except:
118
+ pass
119
+ loaded_model_names.append(model_name)
120
+ loaded_models.append(model)
121
+ return loaded_model_names, loaded_models
122
+
123
+
124
+ def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
125
+ print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
126
+ base_state_dict = base_model.state_dict()
127
+ base_model.to("cpu")
128
+ del base_model
129
+ model = model_class(**extra_kwargs)
130
+ model.load_state_dict(base_state_dict, strict=False)
131
+ model.load_state_dict(state_dict, strict=False)
132
+ model.to(dtype=torch_dtype, device=device)
133
+ return model
134
+
135
+
136
+ def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
137
+ loaded_model_names, loaded_models = [], []
138
+ for model_name, model_class in zip(model_names, model_classes):
139
+ while True:
140
+ for model_id in range(len(model_manager.model)):
141
+ base_model_name = model_manager.model_name[model_id]
142
+ if base_model_name == model_name:
143
+ base_model_path = model_manager.model_path[model_id]
144
+ base_model = model_manager.model[model_id]
145
+ print(f" Adding patch model to {base_model_name} ({base_model_path})")
146
+ patched_model = load_single_patch_model_from_single_file(
147
+ state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
148
+ loaded_model_names.append(base_model_name)
149
+ loaded_models.append(patched_model)
150
+ model_manager.model.pop(model_id)
151
+ model_manager.model_path.pop(model_id)
152
+ model_manager.model_name.pop(model_id)
153
+ break
154
+ else:
155
+ break
156
+ return loaded_model_names, loaded_models
157
+
158
+
159
+
160
+ class ModelDetectorTemplate:
161
+ def __init__(self):
162
+ pass
163
+
164
+ def match(self, file_path="", state_dict={}):
165
+ return False
166
+
167
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
168
+ return [], []
169
+
170
+
171
+
172
+ class ModelDetectorFromSingleFile:
173
+ def __init__(self, model_loader_configs=[]):
174
+ self.keys_hash_with_shape_dict = {}
175
+ self.keys_hash_dict = {}
176
+ for metadata in model_loader_configs:
177
+ self.add_model_metadata(*metadata)
178
+
179
+
180
+ def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
181
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
182
+ if keys_hash is not None:
183
+ self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
184
+
185
+
186
+ def match(self, file_path="", state_dict={}):
187
+ if isinstance(file_path, str) and os.path.isdir(file_path):
188
+ return False
189
+ if len(state_dict) == 0:
190
+ state_dict = load_state_dict(file_path)
191
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
192
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
193
+ return True
194
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
195
+ if keys_hash in self.keys_hash_dict:
196
+ return True
197
+ return False
198
+
199
+
200
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, infer=False, **kwargs):
201
+ if len(state_dict) == 0:
202
+ state_dict = load_state_dict(file_path)
203
+
204
+ # Load models with strict matching
205
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
206
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
207
+ model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
208
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device, infer)
209
+ return loaded_model_names, loaded_models
210
+
211
+ # Load models without strict matching
212
+ # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
213
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
214
+ if keys_hash in self.keys_hash_dict:
215
+ model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
216
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device, infer)
217
+ return loaded_model_names, loaded_models
218
+
219
+ return [], [] # no matching hash found
220
+
221
+
222
+
223
+ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
224
+ def __init__(self, model_loader_configs=[]):
225
+ super().__init__(model_loader_configs)
226
+
227
+
228
+ def match(self, file_path="", state_dict={}):
229
+ if isinstance(file_path, str) and os.path.isdir(file_path):
230
+ return False
231
+ if len(state_dict) == 0:
232
+ state_dict = load_state_dict(file_path)
233
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
234
+ for sub_state_dict in splited_state_dict:
235
+ if super().match(file_path, sub_state_dict):
236
+ return True
237
+ return False
238
+
239
+
240
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
241
+ # Split the state_dict and load from each component
242
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
243
+ valid_state_dict = {}
244
+ for sub_state_dict in splited_state_dict:
245
+ if super().match(file_path, sub_state_dict):
246
+ valid_state_dict.update(sub_state_dict)
247
+ if super().match(file_path, valid_state_dict):
248
+ loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
249
+ else:
250
+ loaded_model_names, loaded_models = [], []
251
+ for sub_state_dict in splited_state_dict:
252
+ if super().match(file_path, sub_state_dict):
253
+ loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
254
+ loaded_model_names += loaded_model_names_
255
+ loaded_models += loaded_models_
256
+ return loaded_model_names, loaded_models
257
+
258
+
259
+
260
+ class ModelDetectorFromPatchedSingleFile:
261
+ def __init__(self, model_loader_configs=[]):
262
+ self.keys_hash_with_shape_dict = {}
263
+ for metadata in model_loader_configs:
264
+ self.add_model_metadata(*metadata)
265
+
266
+
267
+ def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
268
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
269
+
270
+
271
+ def match(self, file_path="", state_dict={}):
272
+ if not isinstance(file_path, str) or os.path.isdir(file_path):
273
+ return False
274
+ if len(state_dict) == 0:
275
+ state_dict = load_state_dict(file_path)
276
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
277
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
278
+ return True
279
+ return False
280
+
281
+
282
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
283
+ if len(state_dict) == 0:
284
+ state_dict = load_state_dict(file_path)
285
+
286
+ # Load models with strict matching
287
+ loaded_model_names, loaded_models = [], []
288
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
289
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
290
+ model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
291
+ loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
292
+ state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
293
+ loaded_model_names += loaded_model_names_
294
+ loaded_models += loaded_models_
295
+ return loaded_model_names, loaded_models
296
+
297
+
298
+
299
+ class ModelManager:
300
+ def __init__(
301
+ self,
302
+ torch_dtype=torch.float16,
303
+ device="cuda",
304
+ model_id_list: List = [],
305
+ downloading_priority: List = ["ModelScope", "HuggingFace"],
306
+ file_path_list: List[str] = [],
307
+ infer: bool = False
308
+ ):
309
+ self.torch_dtype = torch_dtype
310
+ self.device = device
311
+ self.model = []
312
+ self.model_path = []
313
+ self.model_name = []
314
+ self.infer = infer
315
+ downloaded_files = []
316
+ self.model_detector = [
317
+ ModelDetectorFromSingleFile(model_loader_configs),
318
+ ModelDetectorFromSplitedSingleFile(model_loader_configs),
319
+ ]
320
+ self.load_models(downloaded_files + file_path_list)
321
+
322
+ def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
323
+ if isinstance(file_path, list):
324
+ for file_path_ in file_path:
325
+ self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
326
+ else:
327
+ print(f"Loading LoRA models from file: {file_path}")
328
+ is_loaded = False
329
+ if len(state_dict) == 0:
330
+ state_dict = load_state_dict(file_path)
331
+ for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
332
+ lora = GeneralLoRAFromPeft()
333
+ match_results = lora.match(model, state_dict)
334
+ if match_results is not None:
335
+ print(f" Adding LoRA to {model_name} ({model_path}).")
336
+ lora_prefix, model_resource = match_results
337
+ lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
338
+
339
+
340
+
341
+ def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
342
+ print(f"Loading models from file: {file_path}")
343
+ if len(state_dict) == 0:
344
+ state_dict = load_state_dict(file_path)
345
+ model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device, self.infer)
346
+ for model_name, model in zip(model_names, models):
347
+ self.model.append(model)
348
+ self.model_path.append(file_path)
349
+ self.model_name.append(model_name)
350
+ print(f" The following models are loaded: {model_names}.")
351
+
352
+
353
+ def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
354
+ print(f"Loading models from folder: {file_path}")
355
+ model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
356
+ for model_name, model in zip(model_names, models):
357
+ self.model.append(model)
358
+ self.model_path.append(file_path)
359
+ self.model_name.append(model_name)
360
+ print(f" The following models are loaded: {model_names}.")
361
+
362
+
363
+ def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
364
+ print(f"Loading patch models from file: {file_path}")
365
+ model_names, models = load_patch_model_from_single_file(
366
+ state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
367
+ for model_name, model in zip(model_names, models):
368
+ self.model.append(model)
369
+ self.model_path.append(file_path)
370
+ self.model_name.append(model_name)
371
+ print(f" The following patched models are loaded: {model_names}.")
372
+
373
+ def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
374
+ print(f"Loading models from: {file_path}")
375
+ if device is None: device = self.device
376
+ if torch_dtype is None: torch_dtype = self.torch_dtype
377
+ if isinstance(file_path, list):
378
+ state_dict = {}
379
+ for path in file_path:
380
+ state_dict.update(load_state_dict(path))
381
+ elif os.path.isfile(file_path):
382
+ state_dict = load_state_dict(file_path)
383
+ else:
384
+ state_dict = None
385
+ for model_detector in self.model_detector:
386
+ if model_detector.match(file_path, state_dict):
387
+ model_names, models = model_detector.load(
388
+ file_path, state_dict,
389
+ device=device, torch_dtype=torch_dtype,
390
+ allowed_model_names=model_names, model_manager=self, infer=self.infer
391
+ )
392
+ for model_name, model in zip(model_names, models):
393
+ self.model.append(model)
394
+ self.model_path.append(file_path)
395
+ self.model_name.append(model_name)
396
+ print(f" The following models are loaded: {model_names}.")
397
+ break
398
+ else:
399
+ print(f" We cannot detect the model type. No models are loaded.")
400
+
401
+
402
+ def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
403
+ for file_path in file_path_list:
404
+ self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
405
+
406
+
407
+ def fetch_model(self, model_name, file_path=None, require_model_path=False):
408
+ fetched_models = []
409
+ fetched_model_paths = []
410
+ for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
411
+ if file_path is not None and file_path != model_path:
412
+ continue
413
+ if model_name == model_name_:
414
+ fetched_models.append(model)
415
+ fetched_model_paths.append(model_path)
416
+ if len(fetched_models) == 0:
417
+ print(f"No {model_name} models available.")
418
+ return None
419
+ if len(fetched_models) == 1:
420
+ print(f"Using {model_name} from {fetched_model_paths[0]}.")
421
+ else:
422
+ print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
423
+ if require_model_path:
424
+ return fetched_models[0], fetched_model_paths[0]
425
+ else:
426
+ return fetched_models[0]
427
+
428
+
429
+ def to(self, device):
430
+ for model in self.model:
431
+ model.to(device)
432
+
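
The LoRA path above (GeneralLoRAFromPeft.load) reduces to a simple weight update; a standalone sketch of that math with hypothetical tensors:

    # hypothetical tensors; mirrors weight_patched = weight_model + alpha * (lora_B @ lora_A)
    import torch

    weight = torch.randn(768, 768)       # base weight from the model's state dict
    lora_a = torch.randn(16, 768)        # ".lora_A." tensor (rank x in_features)
    lora_b = torch.randn(768, 16)        # ".lora_B." tensor (out_features x rank)
    alpha = 1.0

    patched = weight + alpha * torch.mm(lora_b, lora_a)
    print(patched.shape)                 # torch.Size([768, 768])
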
OmniAvatar/models/vsa_util.py ADDED
@@ -0,0 +1,232 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ import functools
3
+ import math
4
+ from dataclasses import dataclass
5
+ import torch
6
+ from vsa import video_sparse_attn
7
+ from typing import Any
8
+
9
+ VSA_TILE_SIZE = (4, 4, 4)
10
+
11
+
12
+ @functools.lru_cache(maxsize=10)
13
+ def get_tile_partition_indices(
14
+ dit_seq_shape: tuple[int, int, int],
15
+ tile_size: tuple[int, int, int],
16
+ device: torch.device,
17
+ ) -> torch.LongTensor:
18
+ T, H, W = dit_seq_shape
19
+ ts, hs, ws = tile_size
20
+ indices = torch.arange(T * H * W, device=device,
21
+ dtype=torch.long).reshape(T, H, W)
22
+ ls = []
23
+ for t in range(math.ceil(T / ts)):
24
+ for h in range(math.ceil(H / hs)):
25
+ for w in range(math.ceil(W / ws)):
26
+ ls.append(indices[t * ts:min(t * ts + ts, T),
27
+ h * hs:min(h * hs + hs, H),
28
+ w * ws:min(w * ws + ws, W)].flatten())
29
+ index = torch.cat(ls, dim=0)
30
+ return index
31
+
32
+
33
+ @functools.lru_cache(maxsize=10)
34
+ def get_reverse_tile_partition_indices(
35
+ dit_seq_shape: tuple[int, int, int],
36
+ tile_size: tuple[int, int, int],
37
+ device: torch.device,
38
+ ) -> torch.LongTensor:
39
+ return torch.argsort(
40
+ get_tile_partition_indices(dit_seq_shape, tile_size, device))
41
+
42
+
43
+ @functools.lru_cache(maxsize=10)
44
+ def construct_variable_block_sizes(
45
+ dit_seq_shape: tuple[int, int, int],
46
+ num_tiles: tuple[int, int, int],
47
+ device: torch.device,
48
+ ) -> torch.LongTensor:
49
+ """
50
+ Compute the number of valid (non-padded) tokens inside every
51
+ (ts_t × ts_h × ts_w) tile after padding -- flattened in the order
52
+ (t-tile, h-tile, w-tile) that `rearrange` uses.
53
+
54
+ Returns
55
+ -------
56
+ torch.LongTensor # shape: [∏ full_window_size]
57
+ """
58
+ # unpack
59
+ t, h, w = dit_seq_shape
60
+ ts_t, ts_h, ts_w = VSA_TILE_SIZE
61
+ n_t, n_h, n_w = num_tiles
62
+
63
+ def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
64
+ """Vector with the size of each tile along one dimension."""
65
+ sizes = torch.full((n_tiles, ), tile, dtype=torch.int, device=device)
66
+ # size of last (possibly partial) tile
67
+ remainder = dim_len - (n_tiles - 1) * tile
68
+ sizes[-1] = remainder if remainder > 0 else tile
69
+ return sizes
70
+
71
+ t_sizes = _sizes(t, ts_t, n_t) # [n_t]
72
+ h_sizes = _sizes(h, ts_h, n_h) # [n_h]
73
+ w_sizes = _sizes(w, ts_w, n_w) # [n_w]
74
+
75
+ # broadcast‑multiply to get voxels per tile, then flatten
76
+ block_sizes = (
77
+ t_sizes[:, None, None] # [n_t, 1, 1]
78
+ * h_sizes[None, :, None] # [1, n_h, 1]
79
+ * w_sizes[None, None, :] # [1, 1, n_w]
80
+ ).reshape(-1) # [n_t * n_h * n_w]
81
+
82
+ return block_sizes
83
+
84
+
85
+ @functools.lru_cache(maxsize=10)
86
+ def get_non_pad_index(
87
+ variable_block_sizes: torch.LongTensor,
88
+ max_block_size: int,
89
+ ):
90
+ n_win = variable_block_sizes.shape[0]
91
+ device = variable_block_sizes.device
92
+ starts_pad = torch.arange(n_win, device=device) * max_block_size
93
+ index_pad = starts_pad[:, None] + torch.arange(max_block_size,
94
+ device=device)[None, :]
95
+ index_mask = torch.arange(
96
+ max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
97
+ return index_pad[index_mask]
98
+
99
+
100
+
101
+ @dataclass
102
+ class VideoSparseAttentionMetadata():
103
+ current_timestep: int
104
+ dit_seq_shape: list[int]
105
+ VSA_sparsity: float
106
+ num_tiles: list[int]
107
+ total_seq_length: int
108
+ tile_partition_indices: torch.LongTensor
109
+ reverse_tile_partition_indices: torch.LongTensor
110
+ variable_block_sizes: torch.LongTensor
111
+ non_pad_index: torch.LongTensor
112
+
113
+
114
+ def build(
115
+ current_timestep: int,
116
+ raw_latent_shape: tuple[int, int, int],
117
+ patch_size: tuple[int, int, int],
118
+ VSA_sparsity: float,
119
+ device: torch.device,
120
+ **kwargs: dict[str, Any],
121
+ ) -> VideoSparseAttentionMetadata:
122
+ patch_size = patch_size
123
+ dit_seq_shape = (raw_latent_shape[0] // patch_size[0],
124
+ raw_latent_shape[1] // patch_size[1],
125
+ raw_latent_shape[2] // patch_size[2])
126
+
127
+ num_tiles = (math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
128
+ math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
129
+ math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]))
130
+ total_seq_length = math.prod(dit_seq_shape)
131
+
132
+ tile_partition_indices = get_tile_partition_indices(
133
+ dit_seq_shape, VSA_TILE_SIZE, device)
134
+ reverse_tile_partition_indices = get_reverse_tile_partition_indices(
135
+ dit_seq_shape, VSA_TILE_SIZE, device)
136
+ variable_block_sizes = construct_variable_block_sizes(
137
+ dit_seq_shape, num_tiles, device)
138
+ non_pad_index = get_non_pad_index(variable_block_sizes,
139
+ math.prod(VSA_TILE_SIZE))
140
+
141
+ return VideoSparseAttentionMetadata(
142
+ current_timestep=current_timestep,
143
+ dit_seq_shape=dit_seq_shape, # type: ignore
144
+ VSA_sparsity=VSA_sparsity, # type: ignore
145
+ num_tiles=num_tiles, # type: ignore
146
+ total_seq_length=total_seq_length, # type: ignore
147
+ tile_partition_indices=tile_partition_indices, # type: ignore
148
+ reverse_tile_partition_indices=reverse_tile_partition_indices,
149
+ variable_block_sizes=variable_block_sizes,
150
+ non_pad_index=non_pad_index)
151
+
152
+
153
+
154
+ class VideoSparseAttentionImpl():
155
+
156
+ def __init__(
157
+ self,
158
+ num_heads: int,
159
+ head_size: int,
160
+ causal: bool,
161
+ softmax_scale: float,
162
+ num_kv_heads: int | None = None,
163
+ prefix: str = "",
164
+ **extra_impl_args,
165
+ ) -> None:
166
+ self.prefix = prefix
167
+
168
+ def tile(self, x: torch.Tensor, num_tiles: list[int],
169
+ tile_partition_indices: torch.LongTensor,
170
+ non_pad_index: torch.LongTensor) -> torch.Tensor:
171
+ t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
172
+ h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
173
+ w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
174
+
175
+ x_padded = torch.zeros(
176
+ (x.shape[0], t_padded_size * h_padded_size * w_padded_size,
177
+ x.shape[-2], x.shape[-1]),
178
+ device=x.device,
179
+ dtype=x.dtype)
180
+ x_padded[:, non_pad_index] = x[:, tile_partition_indices]
181
+ return x_padded
182
+
183
+ def untile(self, x: torch.Tensor,
184
+ reverse_tile_partition_indices: torch.LongTensor,
185
+ non_pad_index: torch.LongTensor) -> torch.Tensor:
186
+ x = x[:, non_pad_index][:, reverse_tile_partition_indices]
187
+ return x
188
+
189
+ def preprocess_qkv(
190
+ self,
191
+ qkv: torch.Tensor,
192
+ attn_metadata: VideoSparseAttentionMetadata,
193
+ ) -> torch.Tensor:
194
+ return self.tile(qkv, attn_metadata.num_tiles,
195
+ attn_metadata.tile_partition_indices,
196
+ attn_metadata.non_pad_index)
197
+
198
+ def postprocess_output(
199
+ self,
200
+ output: torch.Tensor,
201
+ attn_metadata: VideoSparseAttentionMetadata,
202
+ ) -> torch.Tensor:
203
+ return self.untile(output, attn_metadata.reverse_tile_partition_indices,
204
+ attn_metadata.non_pad_index)
205
+
206
+ def forward( # type: ignore[override]
207
+ self,
208
+ query: torch.Tensor,
209
+ key: torch.Tensor,
210
+ value: torch.Tensor,
211
+ attn_metadata: VideoSparseAttentionMetadata,
212
+ ) -> torch.Tensor:
213
+ query = query.transpose(1, 2).contiguous()
214
+ key = key.transpose(1, 2).contiguous()
215
+ value = value.transpose(1, 2).contiguous()
216
+
217
+ VSA_sparsity = attn_metadata.VSA_sparsity
218
+
219
+ cur_topk = math.ceil(
220
+ (1 - VSA_sparsity) *
221
+ (attn_metadata.total_seq_length / math.prod(VSA_TILE_SIZE)))
222
+
223
+ hidden_states = video_sparse_attn(
224
+ query,
225
+ key,
226
+ value,
227
+ variable_block_sizes=attn_metadata.variable_block_sizes,
228
+ topk=cur_topk,
229
+ block_size=VSA_TILE_SIZE,
230
+ compress_attn_weight=None).transpose(1, 2)
231
+
232
+ return hidden_states
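
For reference, a minimal standalone sketch of how the top-k tile budget computed in forward() follows from the sparsity setting; the (4, 4, 4) tile size is only an assumption standing in for VSA_TILE_SIZE, which is defined elsewhere in vsa_util.py.

import math

def topk_tiles(dit_seq_shape, sparsity, tile_size=(4, 4, 4)):
    # keep (1 - sparsity) of the tiles, rounded up, mirroring cur_topk above
    total_seq_length = math.prod(dit_seq_shape)
    return math.ceil((1 - sparsity) * (total_seq_length / math.prod(tile_size)))

# e.g. a 16 x 32 x 32 token grid at 87.5% sparsity attends to 32 of 256 tiles
assert topk_tiles((16, 32, 32), 0.875) == 32
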
OmniAvatar/models/wan_video_dit.py ADDED
@@ -0,0 +1,607 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from typing import Tuple, Optional
6
+ from einops import rearrange
7
+ from ..utils.io_utils import hash_state_dict_keys
8
+ from .audio_pack import AudioPack
9
+ from ..utils.args_config import args
10
+
11
+ if args.sp_size > 1:
12
+ # Context Parallel
13
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
14
+ get_sequence_parallel_world_size,
15
+ get_sp_group)
16
+
17
+
18
+ try:
19
+ import flash_attn_interface
20
+ print('using flash_attn_interface')
21
+ FLASH_ATTN_3_AVAILABLE = True
22
+ except ModuleNotFoundError:
23
+ FLASH_ATTN_3_AVAILABLE = False
24
+
25
+ try:
26
+ import flash_attn
27
+ print('using flash_attn')
28
+ FLASH_ATTN_2_AVAILABLE = True
29
+ except ModuleNotFoundError:
30
+ FLASH_ATTN_2_AVAILABLE = False
31
+
32
+ try:
33
+ from sageattention import sageattn
34
+ print('using sageattention')
35
+ SAGE_ATTN_AVAILABLE = True
36
+ except ModuleNotFoundError:
37
+ SAGE_ATTN_AVAILABLE = False
38
+
39
+
40
+ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
41
+ if compatibility_mode:
42
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
43
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
44
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
45
+ x = F.scaled_dot_product_attention(q, k, v)
46
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
47
+ elif FLASH_ATTN_3_AVAILABLE:
48
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
49
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
50
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
51
+ x = flash_attn_interface.flash_attn_func(q, k, v)
52
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
53
+ elif FLASH_ATTN_2_AVAILABLE:
54
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
55
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
56
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
57
+ x = flash_attn.flash_attn_func(q, k, v)
58
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
59
+ elif SAGE_ATTN_AVAILABLE:
60
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
61
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
62
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
63
+ x = sageattn(q, k, v)
64
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
65
+ else:
66
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
67
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
68
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
69
+ x = F.scaled_dot_product_attention(q, k, v)
70
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
71
+ return x
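
A small standalone sketch of the layout flash_attention() expects: q/k/v packed as (batch, seq, num_heads * head_dim). It reproduces only the pure-PyTorch fallback branch; the flash-attn / SageAttention paths are drop-in replacements when those packages are installed.

import torch
import torch.nn.functional as F
from einops import rearrange

def sdpa_fallback(q, k, v, num_heads):
    # unpack heads, run scaled dot-product attention, repack
    q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
    k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
    v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
    x = F.scaled_dot_product_attention(q, k, v)
    return rearrange(x, "b n s d -> b s (n d)", n=num_heads)

q = k = v = torch.randn(1, 128, 12 * 64)          # 12 heads of dim 64
assert sdpa_fallback(q, k, v, num_heads=12).shape == (1, 128, 12 * 64)
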
72
+
73
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
74
+ return (x * (1 + scale) + shift)
75
+
76
+
77
+ def sinusoidal_embedding_1d(dim, position):
78
+ sinusoid = torch.outer(position.type(torch.float64), torch.pow(
79
+ 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
80
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
81
+ return x.to(position.dtype)
82
+
83
+ def precompute_freqs_cos_sin(dim: int, end: int = 1024, theta: float = 10000.0):
84
+ # dim is the per-head dim
85
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64)[:(dim//2)] / dim))
86
+ angles = torch.outer(torch.arange(end, dtype=torch.float64, device=freqs.device), freqs) # [end, dim//2]
87
+ return angles.cos().to(torch.float32), angles.sin().to(torch.float32)
88
+
89
+ def precompute_freqs_cos_sin_3d(dim: int, end: int = 1024, theta: float = 10000.0):
90
+ fdim = dim - 2 * (dim // 3)
91
+ hdim = dim // 3
92
+ wdim = dim // 3
93
+ fcos, fsin = precompute_freqs_cos_sin(fdim, end, theta)
94
+ hcos, hsin = precompute_freqs_cos_sin(hdim, end, theta)
95
+ wcos, wsin = precompute_freqs_cos_sin(wdim, end, theta)
96
+ return (fcos, hcos, wcos), (fsin, hsin, wsin)
97
+
98
+ def rope_apply_real(x, cos, sin, num_heads):
99
+ # x: [b, s, n*head_dim]
100
+ x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
101
+ # split last dim into pairs
102
+ d2 = x.shape[-1] // 2
103
+ x = x.reshape(*x.shape[:-1], d2, 2) # [..., d/2, 2]
104
+ x1, x2 = x[..., 0], x[..., 1] # two real halves
105
+
106
+ # cos/sin are shaped [seq, 1, d/2]; broadcast across batch/heads
107
+ rot_x1 = x1 * cos - x2 * sin
108
+ rot_x2 = x1 * sin + x2 * cos
109
+ out = torch.stack((rot_x1, rot_x2), dim=-1).reshape(*x.shape[:-2], -1)
110
+
111
+ return rearrange(out, "b s n d -> b s (n d)")
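
A quick sanity sketch of the 3D rotary split above: the per-head dimension is divided between time (fdim) and the two spatial axes (hdim, wdim), and the halved cos/sin tables concatenate back to head_dim // 2 (assuming, as with head_dim = 128, that each part is even).

head_dim = 128                              # e.g. 1536 / 12 or 5120 / 40
hdim = wdim = head_dim // 3                 # 42 each
fdim = head_dim - 2 * (head_dim // 3)       # 44
assert fdim + hdim + wdim == head_dim
assert fdim // 2 + hdim // 2 + wdim // 2 == head_dim // 2
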
112
+
113
+ class RMSNorm(nn.Module):
114
+ def __init__(self, dim, eps=1e-5):
115
+ super().__init__()
116
+ self.eps = eps
117
+ self.weight = nn.Parameter(torch.ones(dim))
118
+
119
+ def norm(self, x):
120
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
121
+
122
+ def forward(self, x):
123
+ dtype = x.dtype
124
+ return self.norm(x.float()).to(dtype) * self.weight
125
+
126
+
127
+ class AttentionModule(nn.Module):
128
+ def __init__(self, num_heads):
129
+ super().__init__()
130
+ self.num_heads = num_heads
131
+
132
+
133
+ def forward(self, q, k, v):
134
+
135
+ x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
136
+
137
+ return x
138
+
139
+ class SelfAttention(nn.Module):
140
+ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
141
+ super().__init__()
142
+ self.dim = dim
143
+ self.num_heads = num_heads
144
+ self.head_dim = dim // num_heads
145
+
146
+ self.q = nn.Linear(dim, dim)
147
+ self.k = nn.Linear(dim, dim)
148
+ self.v = nn.Linear(dim, dim)
149
+ self.o = nn.Linear(dim, dim)
150
+ self.norm_q = RMSNorm(dim, eps=eps)
151
+ self.norm_k = RMSNorm(dim, eps=eps)
152
+
153
+ self.attn = AttentionModule(self.num_heads)
154
+
155
+ def forward(self, x, freqs):
156
+
157
+ cos, sin = freqs
158
+
159
+ q = self.norm_q(self.q(x))
160
+ k = self.norm_k(self.k(x))
161
+ v = self.v(x)
162
+ # q = rope_apply(q, freqs, self.num_heads)
163
+ # k = rope_apply(k, freqs, self.num_heads)
164
+
165
+ q = rope_apply_real(q, cos, sin, self.num_heads)
166
+ k = rope_apply_real(k, cos, sin, self.num_heads)
167
+ x = self.attn(q, k, v)
168
+ return self.o(x)
169
+
170
+
171
+ class CrossAttention(nn.Module):
172
+ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
173
+ super().__init__()
174
+ self.dim = dim
175
+ self.num_heads = num_heads
176
+ self.head_dim = dim // num_heads
177
+
178
+ self.q = nn.Linear(dim, dim)
179
+ self.k = nn.Linear(dim, dim)
180
+ self.v = nn.Linear(dim, dim)
181
+ self.o = nn.Linear(dim, dim)
182
+ self.norm_q = RMSNorm(dim, eps=eps)
183
+ self.norm_k = RMSNorm(dim, eps=eps)
184
+ self.has_image_input = has_image_input
185
+ if has_image_input:
186
+ self.k_img = nn.Linear(dim, dim)
187
+ self.v_img = nn.Linear(dim, dim)
188
+ self.norm_k_img = RMSNorm(dim, eps=eps)
189
+
190
+ self.attn = AttentionModule(self.num_heads)
191
+
192
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
193
+ if self.has_image_input:
194
+ img = y[:, :257]
195
+ ctx = y[:, 257:]
196
+ else:
197
+ ctx = y
198
+ q = self.norm_q(self.q(x))
199
+ k = self.norm_k(self.k(ctx))
200
+ v = self.v(ctx)
201
+ x = self.attn(q, k, v)
202
+ if self.has_image_input:
203
+ k_img = self.norm_k_img(self.k_img(img))
204
+ v_img = self.v_img(img)
205
+ y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
206
+ x = x + y
207
+ return self.o(x)
208
+
209
+
210
+ class GateModule(nn.Module):
211
+ def __init__(self,):
212
+ super().__init__()
213
+
214
+ def forward(self, x, gate, residual):
215
+ return x + gate * residual
216
+
217
+
218
+ class DiTBlock(nn.Module):
219
+ def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
220
+ super().__init__()
221
+ self.dim = dim
222
+ self.num_heads = num_heads
223
+ self.ffn_dim = ffn_dim
224
+
225
+ self.self_attn = SelfAttention(dim, num_heads, eps)
226
+ self.cross_attn = CrossAttention(
227
+ dim, num_heads, eps, has_image_input=has_image_input)
228
+ self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
229
+ self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
230
+ self.norm3 = nn.LayerNorm(dim, eps=eps)
231
+ self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
232
+ approximate='tanh'), nn.Linear(ffn_dim, dim))
233
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
234
+ self.gate = GateModule()
235
+
236
+ def forward(self, x, context, t_mod, freqs):
237
+ # msa: multi-head self-attention; mlp: multi-layer perceptron
238
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
239
+ self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
240
+ input_x = modulate(self.norm1(x), shift_msa, scale_msa)
241
+ x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
242
+ x = x + self.cross_attn(self.norm3(x), context)
243
+ input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
244
+ x = self.gate(x, gate_mlp, self.ffn(input_x))
245
+ return x
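
For clarity, a shapes-only sketch of the adaLN-style modulation used in DiTBlock: the projected timestep embedding plus the learned per-block table is split into shift/scale/gate triples for the attention and FFN branches. The sizes below are illustrative.

import torch

dim = 1536
t_mod = torch.randn(1, 6, dim)                    # time_projection output, unflattened
modulation = torch.randn(1, 6, dim) / dim**0.5    # per-block learned table
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
    (modulation + t_mod).chunk(6, dim=1)          # each (1, 1, dim)

x = torch.randn(1, 77, dim)
assert (x * (1 + scale_msa) + shift_msa).shape == x.shape   # what modulate() computes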
246
+
247
+ class MLP(nn.Module):
248
+ def __init__(self, in_dim, out_dim):
249
+ super().__init__()
250
+ # keep norms outside the MLP core
251
+ self.ln_in = nn.LayerNorm(in_dim)
252
+ self.fc1 = nn.Linear(in_dim, in_dim)
253
+
254
+ self.activation = nn.GELU()
255
+ self.fc2 = nn.Linear(in_dim, out_dim)
256
+ self.ln_out = nn.LayerNorm(out_dim)
257
+
258
+
259
+ def forward(self, x):
260
+ x = self.ln_in(x)
261
+ x = self.fc2(self.activation(self.fc1(x)))
262
+ x = self.ln_out(x)
263
+ return x
264
+
265
+ class Head(nn.Module):
266
+ def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
267
+ super().__init__()
268
+ self.dim = dim
269
+ self.patch_size = patch_size
270
+ self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
271
+ self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
272
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
273
+
274
+ def forward(self, x, t_mod):
275
+ shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
276
+ x = (self.head(self.norm(x) * (1 + scale) + shift))
277
+ return x
278
+
279
+
280
+
281
+ class WanModel(torch.nn.Module):
282
+ def __init__(
283
+ self,
284
+ dim: int,
285
+ in_dim: int,
286
+ ffn_dim: int,
287
+ out_dim: int,
288
+ text_dim: int,
289
+ freq_dim: int,
290
+ eps: float,
291
+ patch_size: Tuple[int, int, int],
292
+ num_heads: int,
293
+ num_layers: int,
294
+ has_image_input: bool,
295
+ audio_hidden_size: int=32,
296
+ ):
297
+ super().__init__()
298
+ self.dim = dim
299
+ self.freq_dim = freq_dim
300
+ self.has_image_input = has_image_input
301
+ self.patch_size = patch_size
302
+
303
+ self.patch_embedding = nn.Conv3d(
304
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
305
+ # nn.LayerNorm(dim)
306
+ self.text_embedding = nn.Sequential(
307
+ nn.Linear(text_dim, dim),
308
+ nn.GELU(approximate='tanh'),
309
+ nn.Linear(dim, dim)
310
+ )
311
+ self.time_embedding = nn.Sequential(
312
+ nn.Linear(freq_dim, dim),
313
+ nn.SiLU(),
314
+ nn.Linear(dim, dim)
315
+ )
316
+ self.time_projection = nn.Sequential(
317
+ nn.SiLU(), nn.Linear(dim, dim * 6))
318
+ self.blocks = nn.ModuleList([
319
+ DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
320
+ for _ in range(num_layers)
321
+ ])
322
+ self.head = Head(dim, out_dim, patch_size, eps)
323
+ head_dim = dim // num_heads
324
+ self.freqs = precompute_freqs_cos_sin_3d(head_dim)
325
+
326
+ if has_image_input:
327
+ self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
328
+
329
+ if 'use_audio' in args:
330
+ self.use_audio = args.use_audio
331
+ else:
332
+ self.use_audio = False
333
+ if self.use_audio:
334
+ audio_input_dim = 10752
335
+ audio_out_dim = dim
336
+ self.audio_proj = AudioPack(audio_input_dim, [4, 1, 1], audio_hidden_size, layernorm=True)
337
+ self.audio_cond_projs = nn.ModuleList()
338
+ for d in range(num_layers // 2 - 1):
339
+ l = nn.Linear(audio_hidden_size, audio_out_dim)
340
+ self.audio_cond_projs.append(l)
341
+
342
+ def patchify(self, x: torch.Tensor):
343
+ grid_size = x.shape[2:]
344
+ x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
345
+ return x, grid_size # x, grid_size: (f, h, w)
346
+
347
+ def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
348
+ return rearrange(
349
+ x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
350
+ f=grid_size[0], h=grid_size[1], w=grid_size[2],
351
+ x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
352
+ )
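
A standalone check of the unpatchify rearrangement above, using the (1, 2, 2) patch size from the published Wan configs; the token count and latent size are illustrative.

import torch
from einops import rearrange

b, out_dim, (f, h, w), (px, py, pz) = 1, 16, (8, 4, 4), (1, 2, 2)
tokens = torch.randn(b, f * h * w, out_dim * px * py * pz)   # head output
latent = rearrange(tokens, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
                   f=f, h=h, w=w, x=px, y=py, z=pz)
assert latent.shape == (b, out_dim, f * px, h * py, w * pz)  # (1, 16, 8, 8, 8)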
353
+
354
+ def forward(self,
355
+ x: torch.Tensor,
356
+ timestep: torch.Tensor,
357
+ context: torch.Tensor,
358
+ clip_feature: Optional[torch.Tensor] = None,
359
+ y: Optional[torch.Tensor] = None,
360
+ use_gradient_checkpointing: bool = False,
361
+ audio_emb: Optional[torch.Tensor] = None,
362
+ use_gradient_checkpointing_offload: bool = False,
363
+ tea_cache = None,
364
+ **kwargs,
365
+ ):
366
+
367
+ t = self.time_embedding(
368
+ sinusoidal_embedding_1d(self.freq_dim, timestep))
369
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
370
+ context = self.text_embedding(context)
371
+ lat_h, lat_w = x.shape[-2], x.shape[-1]
372
+
373
+ if audio_emb is not None and self.use_audio: # TODO cache
374
+ audio_emb = audio_emb.permute(0, 2, 1)[:, :, :, None, None]
375
+ audio_emb = torch.cat([audio_emb[:, :, :1].repeat(1, 1, 3, 1, 1), audio_emb], 2) # 1, 768, 44, 1, 1
376
+ audio_emb = self.audio_proj(audio_emb)
377
+
378
+ audio_emb = torch.concat([audio_cond_proj(audio_emb) for audio_cond_proj in self.audio_cond_projs], 0)
379
+
380
+ x = torch.cat([x, y], dim=1)
381
+ x = self.patch_embedding(x)
382
+ x, (f, h, w) = self.patchify(x)
383
+
384
+ # freqs = torch.cat([
385
+ # self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
386
+ # self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
387
+ # self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
388
+ # ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
389
+
390
+ (fcos, hcos, wcos), (fsin, hsin, wsin) = self.freqs
391
+ cos = torch.cat([
392
+ fcos[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
393
+ hcos[:h].view(1, h, 1, -1).expand(f, h, w, -1),
394
+ wcos[:w].view(1, 1, w, -1).expand(f, h, w, -1),
395
+ ], dim=-1).reshape(f*h*w, 1, -1).to(x.device, dtype=x.dtype)
396
+ sin = torch.cat([
397
+ fsin[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
398
+ hsin[:h].view(1, h, 1, -1).expand(f, h, w, -1),
399
+ wsin[:w].view(1, 1, w, -1).expand(f, h, w, -1),
400
+ ], dim=-1).reshape(f*h*w, 1, -1).to(x.device, dtype=x.dtype)
401
+ freqs = (cos, sin) # pass both
402
+
403
+ def create_custom_forward(module):
404
+ def custom_forward(*inputs):
405
+ return module(*inputs)
406
+ return custom_forward
407
+
408
+ if tea_cache is not None:
409
+ tea_cache_update = tea_cache.check(self, x, t_mod)
410
+ else:
411
+ tea_cache_update = False
412
+ ori_x_len = x.shape[1]
413
+ if tea_cache_update:
414
+ x = tea_cache.update(x)
415
+ else:
416
+ if args.sp_size > 1:
417
+ # Context Parallel
418
+ sp_size = get_sequence_parallel_world_size()
419
+ pad_size = 0
420
+ if ori_x_len % sp_size != 0:
421
+ pad_size = sp_size - ori_x_len % sp_size
422
+ x = torch.cat([x, torch.zeros_like(x[:, -1:]).repeat(1, pad_size, 1)], 1)
423
+ x = torch.chunk(x, sp_size, dim=1)[get_sequence_parallel_rank()]
424
+
425
+ if self.use_audio:
426
+ audio_emb = audio_emb.reshape(x.shape[0], audio_emb.shape[0] // x.shape[0], -1, *audio_emb.shape[2:])
427
+
428
+ for layer_i, block in enumerate(self.blocks):
429
+ # audio cond
430
+ if self.use_audio:
431
+ au_idx = None
432
+ if (layer_i <= len(self.blocks) // 2 and layer_i > 1): # < len(self.blocks) - 1:
433
+ au_idx = layer_i - 2
434
+ audio_emb_tmp = audio_emb[:, au_idx].repeat(1, 1, lat_h // 2, lat_w // 2, 1) # 1, 11, 45, 25, 128
435
+ audio_cond_tmp = self.patchify(audio_emb_tmp.permute(0, 4, 1, 2, 3))[0]
436
+ if args.sp_size > 1:
437
+ if pad_size > 0:
438
+ audio_cond_tmp = torch.cat([audio_cond_tmp, torch.zeros_like(audio_cond_tmp[:, -1:]).repeat(1, pad_size, 1)], 1)
439
+ audio_cond_tmp = torch.chunk(audio_cond_tmp, sp_size, dim=1)[get_sequence_parallel_rank()]
440
+ x = audio_cond_tmp + x
441
+
442
+ if self.training and use_gradient_checkpointing:
443
+ if use_gradient_checkpointing_offload:
444
+ with torch.autograd.graph.save_on_cpu():
445
+ x = torch.utils.checkpoint.checkpoint(
446
+ create_custom_forward(block),
447
+ x, context, t_mod, freqs,
448
+ use_reentrant=False,
449
+ )
450
+ else:
451
+ x = torch.utils.checkpoint.checkpoint(
452
+ create_custom_forward(block),
453
+ x, context, t_mod, freqs,
454
+ use_reentrant=False,
455
+ )
456
+ else:
457
+ x = block(x, context, t_mod, freqs)
458
+ if tea_cache is not None:
459
+ x_cache = get_sp_group().all_gather(x, dim=1) # TODO: the size should be divided by sp_size
460
+ x_cache = x_cache[:, :ori_x_len]
461
+ tea_cache.store(x_cache)
462
+
463
+ x = self.head(x, t)
464
+ if args.sp_size > 1:
465
+ # Context Parallel
466
+ x = get_sp_group().all_gather(x, dim=1) # TODO: the size should be divided by sp_size
467
+ x = x[:, :ori_x_len]
468
+
469
+ x = self.unpatchify(x, (f, h, w))
470
+ return x
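
A minimal sketch of the context-parallel bookkeeping above: pad the token sequence to a multiple of the sequence-parallel world size, give each rank one chunk, and drop the padding again after all_gather. split_for_rank is a hypothetical helper, not part of the repo.

import torch

def split_for_rank(x, sp_size, rank):
    # pad the sequence dim to a multiple of sp_size, then keep this rank's chunk
    ori_len = x.shape[1]
    pad = (sp_size - ori_len % sp_size) % sp_size
    if pad:
        x = torch.cat([x, torch.zeros_like(x[:, -1:]).repeat(1, pad, 1)], dim=1)
    return torch.chunk(x, sp_size, dim=1)[rank], ori_len

x = torch.randn(1, 130, 64)
shard, ori_len = split_for_rank(x, sp_size=4, rank=2)
assert shard.shape[1] == 33 and ori_len == 130    # 130 -> padded to 132 -> 4 chunks of 33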
471
+
472
+ @staticmethod
473
+ def state_dict_converter():
474
+ return WanModelStateDictConverter()
475
+
476
+
477
+ class WanModelStateDictConverter:
478
+ def __init__(self):
479
+ pass
480
+
481
+ def from_diffusers(self, state_dict):
482
+ rename_dict = {
483
+ "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
484
+ "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
485
+ "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
486
+ "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
487
+ "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
488
+ "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
489
+ "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
490
+ "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
491
+ "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
492
+ "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
493
+ "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
494
+ "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
495
+ "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
496
+ "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
497
+ "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
498
+ "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
499
+ "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
500
+ "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
501
+ "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
502
+ "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
503
+ "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
504
+ "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
505
+ "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
506
+ "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
507
+ "blocks.0.norm2.bias": "blocks.0.norm3.bias",
508
+ "blocks.0.norm2.weight": "blocks.0.norm3.weight",
509
+ "blocks.0.scale_shift_table": "blocks.0.modulation",
510
+ "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
511
+ "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
512
+ "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
513
+ "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
514
+ "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
515
+ "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
516
+ "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
517
+ "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
518
+ "condition_embedder.time_proj.bias": "time_projection.1.bias",
519
+ "condition_embedder.time_proj.weight": "time_projection.1.weight",
520
+ "patch_embedding.bias": "patch_embedding.bias",
521
+ "patch_embedding.weight": "patch_embedding.weight",
522
+ "scale_shift_table": "head.modulation",
523
+ "proj_out.bias": "head.head.bias",
524
+ "proj_out.weight": "head.head.weight",
525
+ }
526
+ state_dict_ = {}
527
+ for name, param in state_dict.items():
528
+ if name in rename_dict:
529
+ state_dict_[rename_dict[name]] = param
530
+ else:
531
+ name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
532
+ if name_ in rename_dict:
533
+ name_ = rename_dict[name_]
534
+ name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
535
+ state_dict_[name_] = param
536
+ if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
537
+ config = {
538
+ "model_type": "t2v",
539
+ "patch_size": (1, 2, 2),
540
+ "text_len": 512,
541
+ "in_dim": 16,
542
+ "dim": 5120,
543
+ "ffn_dim": 13824,
544
+ "freq_dim": 256,
545
+ "text_dim": 4096,
546
+ "out_dim": 16,
547
+ "num_heads": 40,
548
+ "num_layers": 40,
549
+ "window_size": (-1, -1),
550
+ "qk_norm": True,
551
+ "cross_attn_norm": True,
552
+ "eps": 1e-6,
553
+ }
554
+ else:
555
+ config = {}
556
+ return state_dict_, config
557
+
558
+ def from_civitai(self, state_dict):
559
+ if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
560
+ config = {
561
+ "has_image_input": False,
562
+ "patch_size": [1, 2, 2],
563
+ "in_dim": 16,
564
+ "dim": 1536,
565
+ "ffn_dim": 8960,
566
+ "freq_dim": 256,
567
+ "text_dim": 4096,
568
+ "out_dim": 16,
569
+ "num_heads": 12,
570
+ "num_layers": 30,
571
+ "eps": 1e-6
572
+ }
573
+ elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
574
+ config = {
575
+ "has_image_input": False,
576
+ "patch_size": [1, 2, 2],
577
+ "in_dim": 16,
578
+ "dim": 5120,
579
+ "ffn_dim": 13824,
580
+ "freq_dim": 256,
581
+ "text_dim": 4096,
582
+ "out_dim": 16,
583
+ "num_heads": 40,
584
+ "num_layers": 40,
585
+ "eps": 1e-6
586
+ }
587
+ elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
588
+ config = {
589
+ "has_image_input": True,
590
+ "patch_size": [1, 2, 2],
591
+ "in_dim": 36,
592
+ "dim": 5120,
593
+ "ffn_dim": 13824,
594
+ "freq_dim": 256,
595
+ "text_dim": 4096,
596
+ "out_dim": 16,
597
+ "num_heads": 40,
598
+ "num_layers": 40,
599
+ "eps": 1e-6
600
+ }
601
+ else:
602
+ config = {}
603
+ if hasattr(args, "model_config"):
604
+ model_config = args.model_config
605
+ if model_config is not None:
606
+ config.update(model_config)
607
+ return state_dict, config
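
The key-renaming trick in from_diffusers() above only stores the "blocks.0." template; other block indices are mapped onto it and then restored. A standalone illustration with a single dictionary entry:

name = "blocks.7.attn1.to_q.weight"
rename_dict = {"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight"}

parts = name.split(".")
template = ".".join(parts[:1] + ["0"] + parts[2:])     # map onto the blocks.0. template
renamed = rename_dict[template]
renamed = ".".join(renamed.split(".")[:1] + [parts[1]] + renamed.split(".")[2:])
assert renamed == "blocks.7.self_attn.q.weight"        # block index restored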
OmniAvatar/models/wan_video_text_encoder.py ADDED
@@ -0,0 +1,269 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def fp16_clamp(x):
9
+ if x.dtype == torch.float16 and torch.isinf(x).any():
10
+ clamp = torch.finfo(x.dtype).max - 1000
11
+ x = torch.clamp(x, min=-clamp, max=clamp)
12
+ return x
13
+
14
+
15
+ class GELU(nn.Module):
16
+
17
+ def forward(self, x):
18
+ return 0.5 * x * (1.0 + torch.tanh(
19
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
20
+
21
+
22
+ class T5LayerNorm(nn.Module):
23
+
24
+ def __init__(self, dim, eps=1e-6):
25
+ super(T5LayerNorm, self).__init__()
26
+ self.dim = dim
27
+ self.eps = eps
28
+ self.weight = nn.Parameter(torch.ones(dim))
29
+
30
+ def forward(self, x):
31
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
32
+ self.eps)
33
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
34
+ x = x.type_as(self.weight)
35
+ return self.weight * x
36
+
37
+
38
+ class T5Attention(nn.Module):
39
+
40
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
41
+ assert dim_attn % num_heads == 0
42
+ super(T5Attention, self).__init__()
43
+ self.dim = dim
44
+ self.dim_attn = dim_attn
45
+ self.num_heads = num_heads
46
+ self.head_dim = dim_attn // num_heads
47
+
48
+ # layers
49
+ self.q = nn.Linear(dim, dim_attn, bias=False)
50
+ self.k = nn.Linear(dim, dim_attn, bias=False)
51
+ self.v = nn.Linear(dim, dim_attn, bias=False)
52
+ self.o = nn.Linear(dim_attn, dim, bias=False)
53
+ self.dropout = nn.Dropout(dropout)
54
+
55
+ def forward(self, x, context=None, mask=None, pos_bias=None):
56
+ """
57
+ x: [B, L1, C].
58
+ context: [B, L2, C] or None.
59
+ mask: [B, L2] or [B, L1, L2] or None.
60
+ """
61
+ # check inputs
62
+ context = x if context is None else context
63
+ b, n, c = x.size(0), self.num_heads, self.head_dim
64
+
65
+ # compute query, key, value
66
+ q = self.q(x).view(b, -1, n, c)
67
+ k = self.k(context).view(b, -1, n, c)
68
+ v = self.v(context).view(b, -1, n, c)
69
+
70
+ # attention bias
71
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
72
+ if pos_bias is not None:
73
+ attn_bias += pos_bias
74
+ if mask is not None:
75
+ assert mask.ndim in [2, 3]
76
+ mask = mask.view(b, 1, 1,
77
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
78
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
79
+
80
+ # compute attention (T5 does not use scaling)
81
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
82
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
83
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
84
+
85
+ # output
86
+ x = x.reshape(b, -1, n * c)
87
+ x = self.o(x)
88
+ x = self.dropout(x)
89
+ return x
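
A shapes-only sketch of the attention math above: T5-style attention omits the usual 1/sqrt(d) scaling and instead adds a relative-position bias before a float32 softmax. The sizes are illustrative.

import torch
import torch.nn.functional as F

b, n, lq, lk, c = 1, 8, 16, 16, 64
q, k, v = torch.randn(b, lq, n, c), torch.randn(b, lk, n, c), torch.randn(b, lk, n, c)
pos_bias = torch.randn(1, n, lq, lk)                      # relative-position bias

attn = torch.einsum('binc,bjnc->bnij', q, k) + pos_bias   # note: no 1/sqrt(d) scaling
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
out = torch.einsum('bnij,bjnc->binc', attn, v).reshape(b, lq, n * c)
assert out.shape == (1, 16, 512)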
90
+
91
+
92
+ class T5FeedForward(nn.Module):
93
+
94
+ def __init__(self, dim, dim_ffn, dropout=0.1):
95
+ super(T5FeedForward, self).__init__()
96
+ self.dim = dim
97
+ self.dim_ffn = dim_ffn
98
+
99
+ # layers
100
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
101
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
102
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
103
+ self.dropout = nn.Dropout(dropout)
104
+
105
+ def forward(self, x):
106
+ x = self.fc1(x) * self.gate(x)
107
+ x = self.dropout(x)
108
+ x = self.fc2(x)
109
+ x = self.dropout(x)
110
+ return x
111
+
112
+
113
+ class T5SelfAttention(nn.Module):
114
+
115
+ def __init__(self,
116
+ dim,
117
+ dim_attn,
118
+ dim_ffn,
119
+ num_heads,
120
+ num_buckets,
121
+ shared_pos=True,
122
+ dropout=0.1):
123
+ super(T5SelfAttention, self).__init__()
124
+ self.dim = dim
125
+ self.dim_attn = dim_attn
126
+ self.dim_ffn = dim_ffn
127
+ self.num_heads = num_heads
128
+ self.num_buckets = num_buckets
129
+ self.shared_pos = shared_pos
130
+
131
+ # layers
132
+ self.norm1 = T5LayerNorm(dim)
133
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
134
+ self.norm2 = T5LayerNorm(dim)
135
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
136
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
137
+ num_buckets, num_heads, bidirectional=True)
138
+
139
+ def forward(self, x, mask=None, pos_bias=None):
140
+ e = pos_bias if self.shared_pos else self.pos_embedding(
141
+ x.size(1), x.size(1))
142
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
143
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
144
+ return x
145
+
146
+
147
+ class T5RelativeEmbedding(nn.Module):
148
+
149
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
150
+ super(T5RelativeEmbedding, self).__init__()
151
+ self.num_buckets = num_buckets
152
+ self.num_heads = num_heads
153
+ self.bidirectional = bidirectional
154
+ self.max_dist = max_dist
155
+
156
+ # layers
157
+ self.embedding = nn.Embedding(num_buckets, num_heads)
158
+
159
+ def forward(self, lq, lk):
160
+ device = self.embedding.weight.device
161
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
162
+ # torch.arange(lq).unsqueeze(1).to(device)
163
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
164
+ torch.arange(lq, device=device).unsqueeze(1)
165
+ rel_pos = self._relative_position_bucket(rel_pos)
166
+ rel_pos_embeds = self.embedding(rel_pos)
167
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
168
+ 0) # [1, N, Lq, Lk]
169
+ return rel_pos_embeds.contiguous()
170
+
171
+ def _relative_position_bucket(self, rel_pos):
172
+ # preprocess
173
+ if self.bidirectional:
174
+ num_buckets = self.num_buckets // 2
175
+ rel_buckets = (rel_pos > 0).long() * num_buckets
176
+ rel_pos = torch.abs(rel_pos)
177
+ else:
178
+ num_buckets = self.num_buckets
179
+ rel_buckets = 0
180
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
181
+
182
+ # embeddings for small and large positions
183
+ max_exact = num_buckets // 2
184
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
185
+ math.log(self.max_dist / max_exact) *
186
+ (num_buckets - max_exact)).long()
187
+ rel_pos_large = torch.min(
188
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
189
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
190
+ return rel_buckets
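
A standalone illustration of the bidirectional bucketing above with the module defaults (num_buckets=32, max_dist=128): offsets smaller than 8 each get their own bucket, larger offsets share log-spaced buckets, and positive offsets are shifted into the upper half. bucket() is a hypothetical free-function restatement of _relative_position_bucket.

import math
import torch

def bucket(rel_pos, num_buckets=32, max_dist=128):
    half = num_buckets // 2                        # bidirectional: one half per direction
    out = (rel_pos > 0).long() * half
    rel_pos = rel_pos.abs()
    max_exact = half // 2                          # offsets < 8 keep their own bucket
    large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                         math.log(max_dist / max_exact) * (half - max_exact)).long()
    large = torch.min(large, torch.full_like(large, half - 1))
    return out + torch.where(rel_pos < max_exact, rel_pos, large)

print(bucket(torch.tensor([-3, 0, 3, 50])))        # tensor([ 3,  0, 19, 29])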
191
+
192
+ def init_weights(m):
193
+ if isinstance(m, T5LayerNorm):
194
+ nn.init.ones_(m.weight)
195
+ elif isinstance(m, T5FeedForward):
196
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
197
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
198
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
199
+ elif isinstance(m, T5Attention):
200
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
201
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
202
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
203
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
204
+ elif isinstance(m, T5RelativeEmbedding):
205
+ nn.init.normal_(
206
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
207
+
208
+
209
+ class WanTextEncoder(torch.nn.Module):
210
+
211
+ def __init__(self,
212
+ vocab=256384,
213
+ dim=4096,
214
+ dim_attn=4096,
215
+ dim_ffn=10240,
216
+ num_heads=64,
217
+ num_layers=24,
218
+ num_buckets=32,
219
+ shared_pos=False,
220
+ dropout=0.1):
221
+ super(WanTextEncoder, self).__init__()
222
+ self.dim = dim
223
+ self.dim_attn = dim_attn
224
+ self.dim_ffn = dim_ffn
225
+ self.num_heads = num_heads
226
+ self.num_layers = num_layers
227
+ self.num_buckets = num_buckets
228
+ self.shared_pos = shared_pos
229
+
230
+ # layers
231
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
232
+ else nn.Embedding(vocab, dim)
233
+ self.pos_embedding = T5RelativeEmbedding(
234
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
235
+ self.dropout = nn.Dropout(dropout)
236
+ self.blocks = nn.ModuleList([
237
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
238
+ shared_pos, dropout) for _ in range(num_layers)
239
+ ])
240
+ self.norm = T5LayerNorm(dim)
241
+
242
+ # initialize weights
243
+ self.apply(init_weights)
244
+
245
+ def forward(self, ids, mask=None):
246
+ x = self.token_embedding(ids)
247
+ x = self.dropout(x)
248
+ e = self.pos_embedding(x.size(1),
249
+ x.size(1)) if self.shared_pos else None
250
+ for block in self.blocks:
251
+ x = block(x, mask, pos_bias=e)
252
+ x = self.norm(x)
253
+ x = self.dropout(x)
254
+ return x
255
+
256
+ @staticmethod
257
+ def state_dict_converter():
258
+ return WanTextEncoderStateDictConverter()
259
+
260
+
261
+ class WanTextEncoderStateDictConverter:
262
+ def __init__(self):
263
+ pass
264
+
265
+ def from_diffusers(self, state_dict):
266
+ return state_dict
267
+
268
+ def from_civitai(self, state_dict):
269
+ return state_dict
OmniAvatar/models/wan_video_vae.py ADDED
@@ -0,0 +1,807 @@
1
+ from einops import rearrange, repeat
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from tqdm import tqdm
7
+
8
+ CACHE_T = 2
9
+
10
+
11
+ def check_is_instance(model, module_class):
12
+ if isinstance(model, module_class):
13
+ return True
14
+ if hasattr(model, "module") and isinstance(model.module, module_class):
15
+ return True
16
+ return False
17
+
18
+
19
+ def block_causal_mask(x, block_size):
20
+ # params
21
+ b, n, s, _, device = *x.size(), x.device
22
+ assert s % block_size == 0
23
+ num_blocks = s // block_size
24
+
25
+ # build mask
26
+ mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
27
+ for i in range(num_blocks):
28
+ mask[:, :,
29
+ i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
30
+ return mask
31
+
32
+
33
+ class CausalConv3d(nn.Conv3d):
34
+ """
35
+ Causal 3D convolution.
36
+ """
37
+
38
+ def __init__(self, *args, **kwargs):
39
+ super().__init__(*args, **kwargs)
40
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
41
+ self.padding[1], 2 * self.padding[0], 0)
42
+ self.padding = (0, 0, 0)
43
+
44
+ def forward(self, x, cache_x=None):
45
+ padding = list(self._padding)
46
+ if cache_x is not None and self._padding[4] > 0:
47
+ cache_x = cache_x.to(x.device)
48
+ x = torch.cat([cache_x, x], dim=2)
49
+ padding[4] -= cache_x.shape[2]
50
+ x = F.pad(x, padding)
51
+
52
+ return super().forward(x)
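
A quick check of the causal padding scheme above: for a temporal kernel of 3 (padding 1), both temporal pad frames go in front of the clip and none after, so frame t never sees future frames.

import torch
import torch.nn.functional as F

pad_t = 1                                          # temporal kernel 3 -> padding 1
x = torch.randn(1, 8, 5, 16, 16)                   # b c t h w
x_padded = F.pad(x, (1, 1, 1, 1, 2 * pad_t, 0))    # (w, w, h, h, t_front, t_back)
assert x_padded.shape[2] == 5 + 2 * pad_t          # all temporal padding sits in front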
53
+
54
+
55
+ class RMS_norm(nn.Module):
56
+
57
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
58
+ super().__init__()
59
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
60
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
61
+
62
+ self.channel_first = channel_first
63
+ self.scale = dim**0.5
64
+ self.gamma = nn.Parameter(torch.ones(shape))
65
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
66
+
67
+ def forward(self, x):
68
+ return F.normalize(
69
+ x, dim=(1 if self.channel_first else
70
+ -1)) * self.scale * self.gamma + self.bias
71
+
72
+
73
+ class Upsample(nn.Upsample):
74
+
75
+ def forward(self, x):
76
+ """
77
+ Fix bfloat16 support for nearest neighbor interpolation.
78
+ """
79
+ return super().forward(x.float()).type_as(x)
80
+
81
+
82
+ class Resample(nn.Module):
83
+
84
+ def __init__(self, dim, mode):
85
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
86
+ 'downsample3d')
87
+ super().__init__()
88
+ self.dim = dim
89
+ self.mode = mode
90
+
91
+ # layers
92
+ if mode == 'upsample2d':
93
+ self.resample = nn.Sequential(
94
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
95
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
96
+ elif mode == 'upsample3d':
97
+ self.resample = nn.Sequential(
98
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
99
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
100
+ self.time_conv = CausalConv3d(dim,
101
+ dim * 2, (3, 1, 1),
102
+ padding=(1, 0, 0))
103
+
104
+ elif mode == 'downsample2d':
105
+ self.resample = nn.Sequential(
106
+ nn.ZeroPad2d((0, 1, 0, 1)),
107
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
108
+ elif mode == 'downsample3d':
109
+ self.resample = nn.Sequential(
110
+ nn.ZeroPad2d((0, 1, 0, 1)),
111
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
112
+ self.time_conv = CausalConv3d(dim,
113
+ dim, (3, 1, 1),
114
+ stride=(2, 1, 1),
115
+ padding=(0, 0, 0))
116
+
117
+ else:
118
+ self.resample = nn.Identity()
119
+
120
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
121
+ b, c, t, h, w = x.size()
122
+ if self.mode == 'upsample3d':
123
+ if feat_cache is not None:
124
+ idx = feat_idx[0]
125
+ if feat_cache[idx] is None:
126
+ feat_cache[idx] = 'Rep'
127
+ feat_idx[0] += 1
128
+ else:
129
+
130
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
131
+ if cache_x.shape[2] < 2 and feat_cache[
132
+ idx] is not None and feat_cache[idx] != 'Rep':
133
+ # cache the last frames of the last two chunks
134
+ cache_x = torch.cat([
135
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
136
+ cache_x.device), cache_x
137
+ ],
138
+ dim=2)
139
+ if cache_x.shape[2] < 2 and feat_cache[
140
+ idx] is not None and feat_cache[idx] == 'Rep':
141
+ cache_x = torch.cat([
142
+ torch.zeros_like(cache_x).to(cache_x.device),
143
+ cache_x
144
+ ],
145
+ dim=2)
146
+ if feat_cache[idx] == 'Rep':
147
+ x = self.time_conv(x)
148
+ else:
149
+ x = self.time_conv(x, feat_cache[idx])
150
+ feat_cache[idx] = cache_x
151
+ feat_idx[0] += 1
152
+
153
+ x = x.reshape(b, 2, c, t, h, w)
154
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
155
+ 3)
156
+ x = x.reshape(b, c, t * 2, h, w)
157
+ t = x.shape[2]
158
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
159
+ x = self.resample(x)
160
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
161
+
162
+ if self.mode == 'downsample3d':
163
+ if feat_cache is not None:
164
+ idx = feat_idx[0]
165
+ if feat_cache[idx] is None:
166
+ feat_cache[idx] = x.clone()
167
+ feat_idx[0] += 1
168
+ else:
169
+ cache_x = x[:, :, -1:, :, :].clone()
170
+ x = self.time_conv(
171
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
172
+ feat_cache[idx] = cache_x
173
+ feat_idx[0] += 1
174
+ return x
175
+
176
+ def init_weight(self, conv):
177
+ conv_weight = conv.weight
178
+ nn.init.zeros_(conv_weight)
179
+ c1, c2, t, h, w = conv_weight.size()
180
+ one_matrix = torch.eye(c1, c2)
181
+ init_matrix = one_matrix
182
+ nn.init.zeros_(conv_weight)
183
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix
184
+ conv.weight.data.copy_(conv_weight)
185
+ nn.init.zeros_(conv.bias.data)
186
+
187
+ def init_weight2(self, conv):
188
+ conv_weight = conv.weight.data
189
+ nn.init.zeros_(conv_weight)
190
+ c1, c2, t, h, w = conv_weight.size()
191
+ init_matrix = torch.eye(c1 // 2, c2)
192
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
193
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
194
+ conv.weight.data.copy_(conv_weight)
195
+ nn.init.zeros_(conv.bias.data)
196
+
197
+
198
+ class ResidualBlock(nn.Module):
199
+
200
+ def __init__(self, in_dim, out_dim, dropout=0.0):
201
+ super().__init__()
202
+ self.in_dim = in_dim
203
+ self.out_dim = out_dim
204
+
205
+ # layers
206
+ self.residual = nn.Sequential(
207
+ RMS_norm(in_dim, images=False), nn.SiLU(),
208
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
209
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
210
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
211
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
212
+ if in_dim != out_dim else nn.Identity()
213
+
214
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
215
+ h = self.shortcut(x)
216
+ for layer in self.residual:
217
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
218
+ idx = feat_idx[0]
219
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
220
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
221
+ # cache the last frames of the last two chunks
222
+ cache_x = torch.cat([
223
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
224
+ cache_x.device), cache_x
225
+ ],
226
+ dim=2)
227
+ x = layer(x, feat_cache[idx])
228
+ feat_cache[idx] = cache_x
229
+ feat_idx[0] += 1
230
+ else:
231
+ x = layer(x)
232
+ return x + h
233
+
234
+
235
+ class AttentionBlock(nn.Module):
236
+ """
237
+ Causal self-attention with a single head.
238
+ """
239
+
240
+ def __init__(self, dim):
241
+ super().__init__()
242
+ self.dim = dim
243
+
244
+ # layers
245
+ self.norm = RMS_norm(dim)
246
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
247
+ self.proj = nn.Conv2d(dim, dim, 1)
248
+
249
+ # zero out the last layer params
250
+ nn.init.zeros_(self.proj.weight)
251
+
252
+ def forward(self, x):
253
+ identity = x
254
+ b, c, t, h, w = x.size()
255
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
256
+ x = self.norm(x)
257
+ # compute query, key, value
258
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
259
+ 0, 1, 3, 2).contiguous().chunk(3, dim=-1)
260
+
261
+ # apply attention
262
+ x = F.scaled_dot_product_attention(
263
+ q,
264
+ k,
265
+ v,
266
+ #attn_mask=block_causal_mask(q, block_size=h * w)
267
+ )
268
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
269
+
270
+ # output
271
+ x = self.proj(x)
272
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
273
+ return x + identity
274
+
275
+
276
+ class Encoder3d(nn.Module):
277
+
278
+ def __init__(self,
279
+ dim=128,
280
+ z_dim=4,
281
+ dim_mult=[1, 2, 4, 4],
282
+ num_res_blocks=2,
283
+ attn_scales=[],
284
+ temperal_downsample=[True, True, False],
285
+ dropout=0.0):
286
+ super().__init__()
287
+ self.dim = dim
288
+ self.z_dim = z_dim
289
+ self.dim_mult = dim_mult
290
+ self.num_res_blocks = num_res_blocks
291
+ self.attn_scales = attn_scales
292
+ self.temperal_downsample = temperal_downsample
293
+
294
+ # dimensions
295
+ dims = [dim * u for u in [1] + dim_mult]
296
+ scale = 1.0
297
+
298
+ # init block
299
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
300
+
301
+ # downsample blocks
302
+ downsamples = []
303
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
304
+ # residual (+attention) blocks
305
+ for _ in range(num_res_blocks):
306
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
307
+ if scale in attn_scales:
308
+ downsamples.append(AttentionBlock(out_dim))
309
+ in_dim = out_dim
310
+
311
+ # downsample block
312
+ if i != len(dim_mult) - 1:
313
+ mode = 'downsample3d' if temperal_downsample[
314
+ i] else 'downsample2d'
315
+ downsamples.append(Resample(out_dim, mode=mode))
316
+ scale /= 2.0
317
+ self.downsamples = nn.Sequential(*downsamples)
318
+
319
+ # middle blocks
320
+ self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
321
+ AttentionBlock(out_dim),
322
+ ResidualBlock(out_dim, out_dim, dropout))
323
+
324
+ # output blocks
325
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
326
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
327
+
328
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
329
+ if feat_cache is not None:
330
+ idx = feat_idx[0]
331
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
332
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
333
+ # cache the last frames of the last two chunks
334
+ cache_x = torch.cat([
335
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
336
+ cache_x.device), cache_x
337
+ ],
338
+ dim=2)
339
+ x = self.conv1(x, feat_cache[idx])
340
+ feat_cache[idx] = cache_x
341
+ feat_idx[0] += 1
342
+ else:
343
+ x = self.conv1(x)
344
+
345
+ ## downsamples
346
+ for layer in self.downsamples:
347
+ if feat_cache is not None:
348
+ x = layer(x, feat_cache, feat_idx)
349
+ else:
350
+ x = layer(x)
351
+
352
+ ## middle
353
+ for layer in self.middle:
354
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
355
+ x = layer(x, feat_cache, feat_idx)
356
+ else:
357
+ x = layer(x)
358
+
359
+ ## head
360
+ for layer in self.head:
361
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
362
+ idx = feat_idx[0]
363
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
364
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
365
+ # cache the last frames of the last two chunks
366
+ cache_x = torch.cat([
367
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
368
+ cache_x.device), cache_x
369
+ ],
370
+ dim=2)
371
+ x = layer(x, feat_cache[idx])
372
+ feat_cache[idx] = cache_x
373
+ feat_idx[0] += 1
374
+ else:
375
+ x = layer(x)
376
+ return x
377
+
378
+
379
+ class Decoder3d(nn.Module):
380
+
381
+ def __init__(self,
382
+ dim=128,
383
+ z_dim=4,
384
+ dim_mult=[1, 2, 4, 4],
385
+ num_res_blocks=2,
386
+ attn_scales=[],
387
+ temperal_upsample=[False, True, True],
388
+ dropout=0.0):
389
+ super().__init__()
390
+ self.dim = dim
391
+ self.z_dim = z_dim
392
+ self.dim_mult = dim_mult
393
+ self.num_res_blocks = num_res_blocks
394
+ self.attn_scales = attn_scales
395
+ self.temperal_upsample = temperal_upsample
396
+
397
+ # dimensions
398
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
399
+ scale = 1.0 / 2**(len(dim_mult) - 2)
400
+
401
+ # init block
402
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
403
+
404
+ # middle blocks
405
+ self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
406
+ AttentionBlock(dims[0]),
407
+ ResidualBlock(dims[0], dims[0], dropout))
408
+
409
+ # upsample blocks
410
+ upsamples = []
411
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
412
+ # residual (+attention) blocks
413
+ if i == 1 or i == 2 or i == 3:
414
+ in_dim = in_dim // 2
415
+ for _ in range(num_res_blocks + 1):
416
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
417
+ if scale in attn_scales:
418
+ upsamples.append(AttentionBlock(out_dim))
419
+ in_dim = out_dim
420
+
421
+ # upsample block
422
+ if i != len(dim_mult) - 1:
423
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
424
+ upsamples.append(Resample(out_dim, mode=mode))
425
+ scale *= 2.0
426
+ self.upsamples = nn.Sequential(*upsamples)
427
+
428
+ # output blocks
429
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
430
+ CausalConv3d(out_dim, 3, 3, padding=1))
431
+
432
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
433
+ ## conv1
434
+ if feat_cache is not None:
435
+ idx = feat_idx[0]
436
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
437
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
438
+ # cache the last frames of the last two chunks
439
+ cache_x = torch.cat([
440
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
441
+ cache_x.device), cache_x
442
+ ],
443
+ dim=2)
444
+ x = self.conv1(x, feat_cache[idx])
445
+ feat_cache[idx] = cache_x
446
+ feat_idx[0] += 1
447
+ else:
448
+ x = self.conv1(x)
449
+
450
+ ## middle
451
+ for layer in self.middle:
452
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
453
+ x = layer(x, feat_cache, feat_idx)
454
+ else:
455
+ x = layer(x)
456
+
457
+ ## upsamples
458
+ for layer in self.upsamples:
459
+ if feat_cache is not None:
460
+ x = layer(x, feat_cache, feat_idx)
461
+ else:
462
+ x = layer(x)
463
+
464
+ ## head
465
+ for layer in self.head:
466
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
467
+ idx = feat_idx[0]
468
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
469
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
470
+ # cache the last frames of the last two chunks
471
+ cache_x = torch.cat([
472
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
473
+ cache_x.device), cache_x
474
+ ],
475
+ dim=2)
476
+ x = layer(x, feat_cache[idx])
477
+ feat_cache[idx] = cache_x
478
+ feat_idx[0] += 1
479
+ else:
480
+ x = layer(x)
481
+ return x
482
+
483
+
484
+ def count_conv3d(model):
485
+ count = 0
486
+ for m in model.modules():
487
+ if check_is_instance(m, CausalConv3d):
488
+ count += 1
489
+ return count
490
+
491
+
492
+ class VideoVAE_(nn.Module):
493
+
494
+ def __init__(self,
495
+ dim=96,
496
+ z_dim=16,
497
+ dim_mult=[1, 2, 4, 4],
498
+ num_res_blocks=2,
499
+ attn_scales=[],
500
+ temperal_downsample=[False, True, True],
501
+ dropout=0.0):
502
+ super().__init__()
503
+ self.dim = dim
504
+ self.z_dim = z_dim
505
+ self.dim_mult = dim_mult
506
+ self.num_res_blocks = num_res_blocks
507
+ self.attn_scales = attn_scales
508
+ self.temperal_downsample = temperal_downsample
509
+ self.temperal_upsample = temperal_downsample[::-1]
510
+
511
+ # modules
512
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
513
+ attn_scales, self.temperal_downsample, dropout)
514
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
515
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
516
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
517
+ attn_scales, self.temperal_upsample, dropout)
518
+
519
+ def forward(self, x):
520
+ mu, log_var = self.encode(x)
521
+ z = self.reparameterize(mu, log_var)
522
+ x_recon = self.decode(z)
523
+ return x_recon, mu, log_var
524
+
525
+ def encode(self, x, scale):
526
+ self.clear_cache()
527
+ ## cache
528
+ t = x.shape[2]
529
+ iter_ = 1 + (t - 1) // 4
530
+
531
+ for i in range(iter_):
532
+ self._enc_conv_idx = [0]
533
+ if i == 0:
534
+ out = self.encoder(x[:, :, :1, :, :],
535
+ feat_cache=self._enc_feat_map,
536
+ feat_idx=self._enc_conv_idx)
537
+ else:
538
+ out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
539
+ feat_cache=self._enc_feat_map,
540
+ feat_idx=self._enc_conv_idx)
541
+ out = torch.cat([out, out_], 2)
542
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
543
+ if isinstance(scale[0], torch.Tensor):
544
+ scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
545
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
546
+ 1, self.z_dim, 1, 1, 1)
547
+ else:
548
+ scale = scale.to(dtype=mu.dtype, device=mu.device)
549
+ mu = (mu - scale[0]) * scale[1]
550
+ return mu
551
+
552
+ def decode(self, z, scale):
553
+ self.clear_cache()
554
+ # z: [b,c,t,h,w]
555
+ if isinstance(scale[0], torch.Tensor):
556
+ scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
557
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
558
+ 1, self.z_dim, 1, 1, 1)
559
+ else:
560
+ scale = scale.to(dtype=z.dtype, device=z.device)
561
+ z = z / scale[1] + scale[0]
562
+ iter_ = z.shape[2]
563
+ x = self.conv2(z)
564
+ for i in range(iter_):
565
+ self._conv_idx = [0]
566
+ if i == 0:
567
+ out = self.decoder(x[:, :, i:i + 1, :, :],
568
+ feat_cache=self._feat_map,
569
+ feat_idx=self._conv_idx)
570
+ else:
571
+ out_ = self.decoder(x[:, :, i:i + 1, :, :],
572
+ feat_cache=self._feat_map,
573
+ feat_idx=self._conv_idx)
574
+ out = torch.cat([out, out_], 2) # may add tensor offload
575
+ return out
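
The chunked encode/decode above implies the usual Wan-VAE temporal compression: the first latent frame maps to one video frame and every further latent frame to four, so T latent frames decode to 4T - 3 video frames.

T_latent = 11
T_video = 4 * T_latent - 3              # first latent frame -> 1 frame, rest -> 4 each
assert T_video == 41
assert (T_video + 3) // 4 == T_latent   # encoding goes the other way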
576
+
577
+ def reparameterize(self, mu, log_var):
578
+ std = torch.exp(0.5 * log_var)
579
+ eps = torch.randn_like(std)
580
+ return eps * std + mu
581
+
582
+ def sample(self, imgs, deterministic=False):
583
+ mu, log_var = self.encode(imgs)
584
+ if deterministic:
585
+ return mu
586
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
587
+ return mu + std * torch.randn_like(std)
588
+
589
+ def clear_cache(self):
590
+ self._conv_num = count_conv3d(self.decoder)
591
+ self._conv_idx = [0]
592
+ self._feat_map = [None] * self._conv_num
593
+ # cache encode
594
+ self._enc_conv_num = count_conv3d(self.encoder)
595
+ self._enc_conv_idx = [0]
596
+ self._enc_feat_map = [None] * self._enc_conv_num
597
+
598
+
599
+ class WanVideoVAE(nn.Module):
600
+
601
+ def __init__(self, z_dim=16):
602
+ super().__init__()
603
+
604
+ mean = [
605
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
606
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
607
+ ]
608
+ std = [
609
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
610
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
611
+ ]
612
+ self.mean = torch.tensor(mean)
613
+ self.std = torch.tensor(std)
614
+ self.scale = [self.mean, 1.0 / self.std]
615
+
616
+ # init model
617
+ self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
618
+ self.upsampling_factor = 8
619
+
620
+
621
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
622
+ x = torch.ones((length,))
623
+ if not left_bound:
624
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
625
+ if not right_bound:
626
+ x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
627
+ return x
628
+
629
+
630
+ def build_mask(self, data, is_bound, border_width):
631
+ _, _, _, H, W = data.shape
632
+ h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
633
+ w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
634
+
635
+ h = repeat(h, "H -> H W", H=H, W=W)
636
+ w = repeat(w, "W -> H W", H=H, W=W)
637
+
638
+ mask = torch.stack([h, w]).min(dim=0).values
639
+ mask = rearrange(mask, "H W -> 1 1 1 H W")
640
+ return mask
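
A standalone sketch of the 1D ramp build_1d_mask produces: interior tile edges fade linearly over border_width samples so overlapping tiles blend without seams, while edges on the image boundary stay at 1.

import torch

def ramp(length, left_bound, right_bound, border_width):
    x = torch.ones(length)
    if not left_bound:                              # fade in at an interior left edge
        x[:border_width] = (torch.arange(border_width) + 1) / border_width
    if not right_bound:                             # fade out at an interior right edge
        x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
    return x

print(ramp(8, left_bound=False, right_bound=True, border_width=4))
# tensor([0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])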
641
+
642
+
643
+ def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
644
+ _, _, T, H, W = hidden_states.shape
645
+ size_h, size_w = tile_size
646
+ stride_h, stride_w = tile_stride
647
+
648
+ # Split tasks
649
+ tasks = []
650
+ for h in range(0, H, stride_h):
651
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
652
+ for w in range(0, W, stride_w):
653
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
654
+ h_, w_ = h + size_h, w + size_w
655
+ tasks.append((h, h_, w, w_))
656
+
657
+ data_device = "cpu"
658
+ computation_device = device
659
+
660
+ out_T = T * 4 - 3
661
+ weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
662
+ values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
663
+
664
+ for h, h_, w, w_ in tasks:
665
+ hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
666
+ hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
667
+
668
+ mask = self.build_mask(
669
+ hidden_states_batch,
670
+ is_bound=(h==0, h_>=H, w==0, w_>=W),
671
+ border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
672
+ ).to(dtype=hidden_states.dtype, device=data_device)
673
+
674
+ target_h = h * self.upsampling_factor
675
+ target_w = w * self.upsampling_factor
676
+ values[
677
+ :,
678
+ :,
679
+ :,
680
+ target_h:target_h + hidden_states_batch.shape[3],
681
+ target_w:target_w + hidden_states_batch.shape[4],
682
+ ] += hidden_states_batch * mask
683
+ weight[
684
+ :,
685
+ :,
686
+ :,
687
+ target_h: target_h + hidden_states_batch.shape[3],
688
+ target_w: target_w + hidden_states_batch.shape[4],
689
+ ] += mask
690
+ values = values / weight
691
+ values = values.clamp_(-1, 1)
692
+ return values
693
+
694
+
695
+ def tiled_encode(self, video, device, tile_size, tile_stride):
696
+ _, _, T, H, W = video.shape
697
+ size_h, size_w = tile_size
698
+ stride_h, stride_w = tile_stride
699
+
700
+ # Split tasks
701
+ tasks = []
702
+ for h in range(0, H, stride_h):
703
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
704
+ for w in range(0, W, stride_w):
705
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
706
+ h_, w_ = h + size_h, w + size_w
707
+ tasks.append((h, h_, w, w_))
708
+
709
+ data_device = "cpu"
710
+ computation_device = device
711
+
712
+ out_T = (T + 3) // 4
713
+ weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
714
+ values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
715
+
716
+ for h, h_, w, w_ in tasks:
717
+ hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
718
+ hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
719
+
720
+ mask = self.build_mask(
721
+ hidden_states_batch,
722
+ is_bound=(h==0, h_>=H, w==0, w_>=W),
723
+ border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
724
+ ).to(dtype=video.dtype, device=data_device)
725
+
726
+ target_h = h // self.upsampling_factor
727
+ target_w = w // self.upsampling_factor
728
+ values[
729
+ :,
730
+ :,
731
+ :,
732
+ target_h:target_h + hidden_states_batch.shape[3],
733
+ target_w:target_w + hidden_states_batch.shape[4],
734
+ ] += hidden_states_batch * mask
735
+ weight[
736
+ :,
737
+ :,
738
+ :,
739
+ target_h: target_h + hidden_states_batch.shape[3],
740
+ target_w: target_w + hidden_states_batch.shape[4],
741
+ ] += mask
742
+ values = values / weight
743
+ return values
744
+
745
+
746
+ def single_encode(self, video, device):
747
+ video = video.to(device)
748
+ x = self.model.encode(video, self.scale)
749
+ return x
750
+
751
+
752
+ def single_decode(self, hidden_state, device):
753
+ hidden_state = hidden_state.to(device)
754
+ video = self.model.decode(hidden_state, self.scale)
755
+ return video.clamp_(-1, 1)
756
+
757
+
758
+ def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
759
+
760
+ videos = [video.to("cpu") for video in videos]
761
+ hidden_states = []
762
+ for video in videos:
763
+ video = video.unsqueeze(0)
764
+ if tiled:
765
+ tile_size = (tile_size[0] * 8, tile_size[1] * 8)
766
+ tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
767
+ hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
768
+ else:
769
+ hidden_state = self.single_encode(video, device)
770
+ hidden_state = hidden_state.squeeze(0)
771
+ hidden_states.append(hidden_state)
772
+ hidden_states = torch.stack(hidden_states)
773
+ return hidden_states
774
+
775
+
776
+ def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
777
+ hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
778
+ videos = []
779
+ for hidden_state in hidden_states:
780
+ hidden_state = hidden_state.unsqueeze(0)
781
+ if tiled:
782
+ video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
783
+ else:
784
+ video = self.single_decode(hidden_state, device)
785
+ video = video.squeeze(0)
786
+ videos.append(video)
787
+ videos = torch.stack(videos)
788
+ return videos
789
+
790
+
791
+ @staticmethod
792
+ def state_dict_converter():
793
+ return WanVideoVAEStateDictConverter()
794
+
795
+
796
+ class WanVideoVAEStateDictConverter:
797
+
798
+ def __init__(self):
799
+ pass
800
+
801
+ def from_civitai(self, state_dict):
802
+ state_dict_ = {}
803
+ if 'model_state' in state_dict:
804
+ state_dict = state_dict['model_state']
805
+ for name in state_dict:
806
+ state_dict_['model.' + name] = state_dict[name]
807
+ return state_dict_
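
Note on the tiling scheme above: tiled_decode/tiled_encode run the VAE on overlapping spatial tiles and blend the per-tile outputs with the linear feather masks from build_1d_mask/build_mask, accumulating masked results in `values` and the masks in `weight` before dividing. Below is a minimal 1D sketch of just that blending step; the helper function and tile sizes are illustrative, not part of the repository.

# A minimal sketch of the overlap blending used by tiled_encode/tiled_decode,
# reduced to 1D for clarity. Tile results are accumulated into `values` with a
# linear feather mask and normalized by the accumulated `weight`, exactly as the
# 5D version does per (H, W) tile. Sizes here are illustrative only.
import torch

def build_1d_mask(length, left_bound, right_bound, border_width):
    x = torch.ones(length)
    if not left_bound:                      # feather the left edge of interior tiles
        x[:border_width] = (torch.arange(border_width) + 1) / border_width
    if not right_bound:                     # feather the right edge of interior tiles
        x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
    return x

def blend_1d(signal, size=8, stride=4):
    values = torch.zeros_like(signal)
    weight = torch.zeros_like(signal)
    for start in range(0, signal.numel(), stride):
        if start - stride >= 0 and start - stride + size >= signal.numel():
            break                           # same early-exit rule as the tiling loops above
        tile = signal[start:start + size]   # stand-in for a per-tile VAE forward pass
        mask = build_1d_mask(tile.numel(), start == 0, start + size >= signal.numel(), size - stride)
        values[start:start + tile.numel()] += tile * mask
        weight[start:start + tile.numel()] += mask
    return values / weight                  # overlapping regions average back to the input

print(torch.allclose(blend_1d(torch.arange(20.0)), torch.arange(20.0)))  # True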
OmniAvatar/models/wav2vec.py ADDED
@@ -0,0 +1,208 @@
1
+ # pylint: disable=R0901
2
+ # src/models/wav2vec.py
3
+
4
+ """
5
+ This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
6
+ It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
7
+ such as feature extraction and encoding.
8
+
9
+ Classes:
10
+ Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
11
+
12
+ Functions:
13
+ linear_interpolation: Interpolates the features based on the sequence length.
14
+ """
15
+
16
+ import torch.nn.functional as F
17
+ from transformers import Wav2Vec2Model
18
+ from transformers.modeling_outputs import BaseModelOutput
19
+
20
+
21
+ class Wav2VecModel(Wav2Vec2Model):
22
+ """
23
+ Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
24
+ It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
25
+ ...
26
+
27
+ Attributes:
28
+ base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
29
+
30
+ Methods:
31
+ forward(input_values, seq_len, attention_mask=None, mask_time_indices=None
32
+ , output_attentions=None, output_hidden_states=None, return_dict=None):
33
+ Forward pass of the Wav2VecModel.
34
+ It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
35
+
36
+ feature_extract(input_values, seq_len):
37
+ Extracts features from the input_values using the base model.
38
+
39
+ encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
40
+ Encodes the extracted features using the base model and returns the encoded features.
41
+ """
42
+ def forward(
43
+ self,
44
+ input_values,
45
+ seq_len,
46
+ attention_mask=None,
47
+ mask_time_indices=None,
48
+ output_attentions=None,
49
+ output_hidden_states=None,
50
+ return_dict=None,
51
+ ):
52
+ """
53
+ Forward pass of the Wav2Vec model.
54
+
55
+ Args:
56
+ self: The instance of the model.
57
+ input_values: The input values (waveform) to the model.
58
+ seq_len: The sequence length of the input values.
59
+ attention_mask: Attention mask to be used for the model.
60
+ mask_time_indices: Mask indices to be used for the model.
61
+ output_attentions: If set to True, returns attentions.
62
+ output_hidden_states: If set to True, returns hidden states.
63
+ return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
64
+
65
+ Returns:
66
+ The output of the Wav2Vec model.
67
+ """
68
+
69
+ output_hidden_states = (
70
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
71
+ )
72
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
73
+
74
+ extract_features = self.feature_extractor(input_values)
75
+ extract_features = extract_features.transpose(1, 2)
76
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
77
+
78
+ if attention_mask is not None:
79
+ # compute reduced attention_mask corresponding to feature vectors
80
+ attention_mask = self._get_feature_vector_attention_mask(
81
+ extract_features.shape[1], attention_mask, add_adapter=False
82
+ )
83
+
84
+ hidden_states, extract_features = self.feature_projection(extract_features)
85
+ hidden_states = self._mask_hidden_states(
86
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
87
+ )
88
+
89
+ encoder_outputs = self.encoder(
90
+ hidden_states,
91
+ attention_mask=attention_mask,
92
+ output_attentions=output_attentions,
93
+ output_hidden_states=output_hidden_states,
94
+ return_dict=return_dict,
95
+ )
96
+
97
+ hidden_states = encoder_outputs[0]
98
+
99
+ if self.adapter is not None:
100
+ hidden_states = self.adapter(hidden_states)
101
+
102
+ if not return_dict:
103
+ return (hidden_states, ) + encoder_outputs[1:]
104
+ return BaseModelOutput(
105
+ last_hidden_state=hidden_states,
106
+ hidden_states=encoder_outputs.hidden_states,
107
+ attentions=encoder_outputs.attentions,
108
+ )
109
+
110
+
111
+ def feature_extract(
112
+ self,
113
+ input_values,
114
+ seq_len,
115
+ ):
116
+ """
117
+ Extracts features from the input values and returns the extracted features.
118
+
119
+ Parameters:
120
+ input_values (torch.Tensor): The input values to be processed.
121
+ seq_len (torch.Tensor): The sequence lengths of the input values.
122
+
123
+ Returns:
124
+ extracted_features (torch.Tensor): The extracted features from the input values.
125
+ """
126
+ extract_features = self.feature_extractor(input_values)
127
+ extract_features = extract_features.transpose(1, 2)
128
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
129
+
130
+ return extract_features
131
+
132
+ def encode(
133
+ self,
134
+ extract_features,
135
+ attention_mask=None,
136
+ mask_time_indices=None,
137
+ output_attentions=None,
138
+ output_hidden_states=None,
139
+ return_dict=None,
140
+ ):
141
+ """
142
+ Encodes the input features into the output space.
143
+
144
+ Args:
145
+ extract_features (torch.Tensor): The extracted features from the audio signal.
146
+ attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
147
+ mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
148
+ output_attentions (bool, optional): If set to True, returns the attention weights.
149
+ output_hidden_states (bool, optional): If set to True, returns all hidden states.
150
+ return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple.
151
+
152
+ Returns:
153
+ The encoded output features.
154
+ """
155
+ self.config.output_attentions = True
156
+
157
+ output_hidden_states = (
158
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
159
+ )
160
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
161
+
162
+ if attention_mask is not None:
163
+ # compute reduced attention_mask corresponding to feature vectors
164
+ attention_mask = self._get_feature_vector_attention_mask(
165
+ extract_features.shape[1], attention_mask, add_adapter=False
166
+ )
167
+
168
+ hidden_states, extract_features = self.feature_projection(extract_features)
169
+ hidden_states = self._mask_hidden_states(
170
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
171
+ )
172
+
173
+ encoder_outputs = self.encoder(
174
+ hidden_states,
175
+ attention_mask=attention_mask,
176
+ output_attentions=output_attentions,
177
+ output_hidden_states=output_hidden_states,
178
+ return_dict=return_dict,
179
+ )
180
+
181
+ hidden_states = encoder_outputs[0]
182
+
183
+ if self.adapter is not None:
184
+ hidden_states = self.adapter(hidden_states)
185
+
186
+ if not return_dict:
187
+ return (hidden_states, ) + encoder_outputs[1:]
188
+ return BaseModelOutput(
189
+ last_hidden_state=hidden_states,
190
+ hidden_states=encoder_outputs.hidden_states,
191
+ attentions=encoder_outputs.attentions,
192
+ )
193
+
194
+
195
+ def linear_interpolation(features, seq_len):
196
+ """
197
+ Transpose the features to interpolate linearly.
198
+
199
+ Args:
200
+ features (torch.Tensor): The extracted features to be interpolated.
201
+ seq_len (torch.Tensor): The sequence lengths of the features.
202
+
203
+ Returns:
204
+ torch.Tensor: The interpolated features.
205
+ """
206
+ features = features.transpose(1, 2)
207
+ output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
208
+ return output_features.transpose(1, 2)
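
For reference, linear_interpolation is what aligns the audio feature rate to the video frame rate. A small self-contained sketch of the same transpose/interpolate/transpose round trip follows; the tensor sizes are made up, and "roughly 50 wav2vec feature frames per second of 16 kHz audio" is an assumption about the upstream feature extractor.

# (batch, T_audio, dim) -> (batch, seq_len, dim), resampled along time.
import torch
import torch.nn.functional as F

features = torch.randn(1, 200, 768)           # ~4 s of audio at ~50 feature frames/s (assumed)
seq_len = 100                                 # e.g. 4 s of video at 25 fps

x = features.transpose(1, 2)                  # (B, dim, T) as required by F.interpolate
x = F.interpolate(x, size=seq_len, mode="linear", align_corners=True)
aligned = x.transpose(1, 2)                   # (B, seq_len, dim)

print(aligned.shape)                          # torch.Size([1, 100, 768])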
OmniAvatar/prompters/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .wan_prompter import WanPrompter
OmniAvatar/prompters/base_prompter.py ADDED
@@ -0,0 +1,70 @@
1
+ from ..models.model_manager import ModelManager
2
+ import torch
3
+
4
+
5
+
6
+ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
7
+ # Get model_max_length from self.tokenizer
8
+ length = tokenizer.model_max_length if max_length is None else max_length
9
+
10
+ # To avoid the length warning, temporarily set tokenizer.model_max_length to a very large value.
11
+ tokenizer.model_max_length = 99999999
12
+
13
+ # Tokenize it!
14
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
15
+
16
+ # Determine the real length.
17
+ max_length = (input_ids.shape[1] + length - 1) // length * length
18
+
19
+ # Restore tokenizer.model_max_length
20
+ tokenizer.model_max_length = length
21
+
22
+ # Tokenize it again with fixed length.
23
+ input_ids = tokenizer(
24
+ prompt,
25
+ return_tensors="pt",
26
+ padding="max_length",
27
+ max_length=max_length,
28
+ truncation=True
29
+ ).input_ids
30
+
31
+ # Reshape input_ids to fit the text encoder.
32
+ num_sentence = input_ids.shape[1] // length
33
+ input_ids = input_ids.reshape((num_sentence, length))
34
+
35
+ return input_ids
36
+
37
+
38
+
39
+ class BasePrompter:
40
+ def __init__(self):
41
+ self.refiners = []
42
+ self.extenders = []
43
+
44
+
45
+ def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
46
+ for refiner_class in refiner_classes:
47
+ refiner = refiner_class.from_model_manager(model_manager)
48
+ self.refiners.append(refiner)
49
+
50
+ def load_prompt_extenders(self,model_manager:ModelManager,extender_classes=[]):
51
+ for extender_class in extender_classes:
52
+ extender = extender_class.from_model_manager(model_manager)
53
+ self.extenders.append(extender)
54
+
55
+
56
+ @torch.no_grad()
57
+ def process_prompt(self, prompt, positive=True):
58
+ if isinstance(prompt, list):
59
+ prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
60
+ else:
61
+ for refiner in self.refiners:
62
+ prompt = refiner(prompt, positive=positive)
63
+ return prompt
64
+
65
+ @torch.no_grad()
66
+ def extend_prompt(self, prompt:str, positive=True):
67
+ extended_prompt = dict(prompt=prompt)
68
+ for extender in self.extenders:
69
+ extended_prompt = extender(extended_prompt)
70
+ return extended_prompt
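
tokenize_long_prompt rounds the tokenized length up to the next multiple of model_max_length and reshapes the ids into fixed-length chunks for the text encoder. The arithmetic in isolation, with illustrative numbers:

# Ceil-to-multiple rounding and the (num_sentence, length) split used above.
length = 77                                   # tokenizer.model_max_length (illustrative)
num_tokens = 180                              # length of the un-truncated prompt (illustrative)

max_length = (num_tokens + length - 1) // length * length   # ceil to a multiple of 77 -> 231
num_sentence = max_length // length                          # -> 3 fixed-length chunks

print(max_length, num_sentence)               # 231 3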
OmniAvatar/prompters/wan_prompter.py ADDED
@@ -0,0 +1,109 @@
1
+ from .base_prompter import BasePrompter
2
+ from ..models.wan_video_text_encoder import WanTextEncoder
3
+ from transformers import AutoTokenizer
4
+ import os, torch
5
+ import ftfy
6
+ import html
7
+ import string
8
+ import regex as re
9
+
10
+
11
+ def basic_clean(text):
12
+ text = ftfy.fix_text(text)
13
+ text = html.unescape(html.unescape(text))
14
+ return text.strip()
15
+
16
+
17
+ def whitespace_clean(text):
18
+ text = re.sub(r'\s+', ' ', text)
19
+ text = text.strip()
20
+ return text
21
+
22
+
23
+ def canonicalize(text, keep_punctuation_exact_string=None):
24
+ text = text.replace('_', ' ')
25
+ if keep_punctuation_exact_string:
26
+ text = keep_punctuation_exact_string.join(
27
+ part.translate(str.maketrans('', '', string.punctuation))
28
+ for part in text.split(keep_punctuation_exact_string))
29
+ else:
30
+ text = text.translate(str.maketrans('', '', string.punctuation))
31
+ text = text.lower()
32
+ text = re.sub(r'\s+', ' ', text)
33
+ return text.strip()
34
+
35
+
36
+ class HuggingfaceTokenizer:
37
+
38
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
39
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
40
+ self.name = name
41
+ self.seq_len = seq_len
42
+ self.clean = clean
43
+
44
+ # init tokenizer
45
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
46
+ self.vocab_size = self.tokenizer.vocab_size
47
+
48
+ def __call__(self, sequence, **kwargs):
49
+ return_mask = kwargs.pop('return_mask', False)
50
+
51
+ # arguments
52
+ _kwargs = {'return_tensors': 'pt'}
53
+ if self.seq_len is not None:
54
+ _kwargs.update({
55
+ 'padding': 'max_length',
56
+ 'truncation': True,
57
+ 'max_length': self.seq_len
58
+ })
59
+ _kwargs.update(**kwargs)
60
+
61
+ # tokenization
62
+ if isinstance(sequence, str):
63
+ sequence = [sequence]
64
+ if self.clean:
65
+ sequence = [self._clean(u) for u in sequence]
66
+ ids = self.tokenizer(sequence, **_kwargs)
67
+
68
+ # output
69
+ if return_mask:
70
+ return ids.input_ids, ids.attention_mask
71
+ else:
72
+ return ids.input_ids
73
+
74
+ def _clean(self, text):
75
+ if self.clean == 'whitespace':
76
+ text = whitespace_clean(basic_clean(text))
77
+ elif self.clean == 'lower':
78
+ text = whitespace_clean(basic_clean(text)).lower()
79
+ elif self.clean == 'canonicalize':
80
+ text = canonicalize(basic_clean(text))
81
+ return text
82
+
83
+
84
+ class WanPrompter(BasePrompter):
85
+
86
+ def __init__(self, tokenizer_path=None, text_len=512):
87
+ super().__init__()
88
+ self.text_len = text_len
89
+ self.text_encoder = None
90
+ self.fetch_tokenizer(tokenizer_path)
91
+
92
+ def fetch_tokenizer(self, tokenizer_path=None):
93
+ if tokenizer_path is not None:
94
+ self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')
95
+
96
+ def fetch_models(self, text_encoder: WanTextEncoder = None):
97
+ self.text_encoder = text_encoder
98
+
99
+ def encode_prompt(self, prompt, positive=True, device="cuda"):
100
+ prompt = self.process_prompt(prompt, positive=positive)
101
+
102
+ ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
103
+ ids = ids.to(device)
104
+ mask = mask.to(device)
105
+ seq_lens = mask.gt(0).sum(dim=1).long()
106
+ prompt_emb = self.text_encoder(ids, mask)
107
+ for i, v in enumerate(seq_lens):
108
+ prompt_emb[:, v:] = 0
109
+ return prompt_emb
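
encode_prompt pads every prompt to text_len and then zeroes the embedding positions beyond each prompt's true token count so the padding cannot leak into cross-attention. A toy sketch of that masking step; the dimensions are arbitrary, and the per-sample index `[i, v:]` used here matches the `[:, v:]` form above when prompts are encoded one at a time, which is how the pipeline calls it.

import torch

prompt_emb = torch.randn(2, 512, 8)           # (batch, text_len, dim); dim is illustrative
mask = torch.zeros(2, 512, dtype=torch.long)
mask[0, :7] = 1                               # first prompt has 7 real tokens
mask[1, :3] = 1                               # second prompt has 3 real tokens

seq_lens = mask.gt(0).sum(dim=1).long()
for i, v in enumerate(seq_lens):
    prompt_emb[i, v:] = 0                     # zero the padded tail of each sample

print(prompt_emb[0, 7:].abs().sum(), prompt_emb[1, 3:].abs().sum())   # both tensor(0.)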
OmniAvatar/schedulers/flow_match.py ADDED
@@ -0,0 +1,79 @@
1
+ import torch
2
+
3
+
4
+
5
+ class FlowMatchScheduler():
6
+
7
+ def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
8
+ self.num_train_timesteps = num_train_timesteps
9
+ self.shift = shift
10
+ self.sigma_max = sigma_max
11
+ self.sigma_min = sigma_min
12
+ self.inverse_timesteps = inverse_timesteps
13
+ self.extra_one_step = extra_one_step
14
+ self.reverse_sigmas = reverse_sigmas
15
+ self.set_timesteps(num_inference_steps)
16
+
17
+
18
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
19
+ if shift is not None:
20
+ self.shift = shift
21
+ sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
22
+ if self.extra_one_step:
23
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
24
+ else:
25
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
26
+ if self.inverse_timesteps:
27
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
28
+ self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
29
+ if self.reverse_sigmas:
30
+ self.sigmas = 1 - self.sigmas
31
+ self.timesteps = self.sigmas * self.num_train_timesteps
32
+ if training:
33
+ x = self.timesteps
34
+ y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
35
+ y_shifted = y - y.min()
36
+ bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
37
+ self.linear_timesteps_weights = bsmntw_weighing
38
+
39
+
40
+ def step(self, model_output, timestep, sample, to_final=False, **kwargs):
41
+ if isinstance(timestep, torch.Tensor):
42
+ timestep = timestep.cpu()
43
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
44
+ sigma = self.sigmas[timestep_id]
45
+ if to_final or timestep_id + 1 >= len(self.timesteps):
46
+ sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
47
+ else:
48
+ sigma_ = self.sigmas[timestep_id + 1]
49
+ prev_sample = sample + model_output * (sigma_ - sigma)
50
+ return prev_sample
51
+
52
+
53
+ def return_to_timestep(self, timestep, sample, sample_stablized):
54
+ if isinstance(timestep, torch.Tensor):
55
+ timestep = timestep.cpu()
56
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
57
+ sigma = self.sigmas[timestep_id]
58
+ model_output = (sample - sample_stablized) / sigma
59
+ return model_output
60
+
61
+
62
+ def add_noise(self, original_samples, noise, timestep):
63
+ if isinstance(timestep, torch.Tensor):
64
+ timestep = timestep.cpu()
65
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
66
+ sigma = self.sigmas[timestep_id]
67
+ sample = (1 - sigma) * original_samples + sigma * noise
68
+ return sample
69
+
70
+
71
+ def training_target(self, sample, noise, timestep):
72
+ target = noise - sample
73
+ return target
74
+
75
+
76
+ def training_weight(self, timestep):
77
+ timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
78
+ weights = self.linear_timesteps_weights[timestep_id]
79
+ return weights
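
A quick sanity check of the flow-matching relations implemented above: add_noise builds x_t = (1 - σ)·x0 + σ·ε, training_target returns v = ε - x0, and step moves the sample by (σ_next - σ)·v, so a perfect prediction stepped straight to σ = 0 recovers x0 exactly.

import torch

x0 = torch.randn(4)
noise = torch.randn(4)
sigma = 0.7

x_t = (1 - sigma) * x0 + sigma * noise        # add_noise
v = noise - x0                                # training_target
x_recovered = x_t + v * (0.0 - sigma)         # step() with sigma_ = 0 (final step)

print(torch.allclose(x_recovered, x0))        # True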
OmniAvatar/utils/args_config.py ADDED
@@ -0,0 +1,123 @@
1
+ import json
2
+ import os
3
+ import argparse
4
+ import yaml
5
+ args = None
6
+
7
+ def set_global_args(local_args):
8
+ global args
9
+
10
+ args = local_args
11
+
12
+ def parse_hp_string(hp_string):
13
+ result = {}
14
+ for pair in hp_string.split(','):
15
+ if not pair:
16
+ continue
17
+ key, value = pair.split('=')
18
+ try:
19
+ # Automatically cast to int / float / str
20
+ ori_value = value
21
+ value = float(value)
22
+ if '.' not in str(ori_value):
23
+ value = int(value)
24
+ except ValueError:
25
+ pass
26
+
27
+ if value in ['true', 'True']:
28
+ value = True
29
+ if value in ['false', 'False']:
30
+ value = False
31
+ if '.' in key:
32
+ keys = key.split('.')
33
+ keys = keys
34
+ current = result
35
+ for key in keys[:-1]:
36
+ if key not in current or not isinstance(current[key], dict):
37
+ current[key] = {}
38
+ current = current[key]
39
+ current[keys[-1]] = value
40
+ else:
41
+ result[key.strip()] = value
42
+ return result
43
+
44
+ def parse_args():
45
+ global args
46
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
47
+ parser.add_argument("--config", type=str, required=True, help="Path to YAML config file.")
48
+
49
+ # Define argparse arguments
50
+ parser.add_argument("--exp_path", type=str, help="Path to save the model.")
51
+ parser.add_argument("--input_file", type=str, help="Path to inference txt.")
52
+ parser.add_argument("--debug", action='store_true', default=None)
53
+ parser.add_argument("--infer", action='store_true')
54
+ parser.add_argument("-hp", "--hparams", type=str, default="")
55
+
56
+ args = parser.parse_args()
57
+
58
+ # Read the YAML config (when --config is provided)
60
+ if args.config:
61
+ with open(args.config, "r") as f:
62
+ yaml_config = yaml.safe_load(f)
63
+
64
+ # Walk the YAML config and add its entries to args when argparse does not define them
65
+ for key, value in yaml_config.items():
66
+ if not hasattr(args, key): # argument not defined by argparse
67
+ setattr(args, key, value)
68
+ elif getattr(args, key) is None: # defined by argparse but left unset
68
+ setattr(args, key, value)
69
+
70
+ args.rank = int(os.getenv("RANK", "0"))
71
+ args.world_size = int(os.getenv("WORLD_SIZE", "1"))
72
+ args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun
73
+ args.device = 'cuda'
74
+ debug = args.debug
75
+ if not os.path.exists(args.exp_path):
76
+ args.exp_path = f'checkpoints/{args.exp_path}'
77
+
78
+ if hasattr(args, 'reload_cfg') and args.reload_cfg:
79
+ # Reload the saved config file
80
+ conf_path = os.path.join(args.exp_path, "config.json")
81
+ if os.path.exists(conf_path):
82
+ print('| Reloading config from:', conf_path)
83
+ args = reload(args, conf_path)
84
+ if len(args.hparams) > 0:
85
+ hp_dict = parse_hp_string(args.hparams)
86
+ for key, value in hp_dict.items():
87
+ if not hasattr(args, key):
88
+ setattr(args, key, value)
89
+ else:
90
+ if isinstance(value, dict):
91
+ ori_v = getattr(args, key)
92
+ ori_v.update(value)
93
+ setattr(args, key, ori_v)
94
+ else:
95
+ setattr(args, key, value)
96
+ args.debug = debug
97
+ dict_args = convert_namespace_to_dict(args)
98
+ if args.local_rank == 0:
99
+ print(dict_args)
100
+ return args
101
+
102
+ def reload(args, conf_path):
103
+ """重新加载配置文件,不覆盖已有的参数"""
104
+ with open(conf_path, "r") as f:
105
+ yaml_config = yaml.safe_load(f)
106
+ # Walk the YAML config and add its entries to args when argparse does not define them
107
+ for key, value in yaml_config.items():
108
+ if not hasattr(args, key): # argument not defined by argparse
109
+ setattr(args, key, value)
110
+ elif getattr(args, key) is None: # defined by argparse but left unset
111
+ setattr(args, key, value)
112
+ return args
113
+
114
+ def convert_namespace_to_dict(namespace):
115
+ """将 argparse.Namespace 转为字典,并处理不可序列化对象"""
116
+ result = {}
117
+ for key, value in vars(namespace).items():
118
+ try:
119
+ json.dumps(value) # check that the value is JSON-serializable
120
+ result[key] = value
121
+ except (TypeError, OverflowError):
122
+ result[key] = str(value) # 将不可序列化的对象转为字符串表示
123
+ return result
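
parse_hp_string turns a comma-separated `key=value` override string into a dict, auto-casting ints/floats/bools and expanding dotted keys into nested dicts. An illustrative call; the keys are made up, not a list of supported hyperparameters.

from OmniAvatar.utils.args_config import parse_hp_string

print(parse_hp_string("num_steps=25,cfg_scale=4.5,use_audio=True,model.sp_size=2"))
# {'num_steps': 25, 'cfg_scale': 4.5, 'use_audio': True, 'model': {'sp_size': 2}}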
OmniAvatar/utils/audio_preprocess.py ADDED
@@ -0,0 +1,21 @@
1
+ import os
2
+ import subprocess
3
+
4
+ def add_silence_to_audio_ffmpeg(audio_path, tmp_audio_path, silence_duration_s=0.5):
5
+ # Use ffmpeg to prepend silence to the audio
6
+ cmd = [
7
+ 'ffmpeg',
8
+ '-i', audio_path, # input audio file path
9
+ '-f', 'lavfi', # use the lavfi virtual input device to generate silence
10
+ '-t', str(silence_duration_s), # silence duration in seconds
11
+ '-i', 'anullsrc=r=16000:cl=stereo', # generate a stereo silence clip at a 16 kHz sample rate
12
+ '-filter_complex', '[1][0]concat=n=2:v=0:a=1[out]', # concatenate the silence and the original audio
13
+ '-map', '[out]', # write the concatenated audio
14
+ '-y', tmp_audio_path, # output file path
15
+ '-loglevel', 'error'
16
+ ]
17
+
18
+ try:
19
+ subprocess.run(cmd, check=True, capture_output=True, text=True)
20
+ except subprocess.CalledProcessError as e:
21
+ raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
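
Example usage of the helper above; the paths and duration are placeholders, and ffmpeg must be available on PATH. It writes a copy of the input with 0.5 s of 16 kHz stereo silence prepended.

from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg

add_silence_to_audio_ffmpeg("examples/audios/script.wav", "/tmp/script_padded.wav", silence_duration_s=0.5)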
OmniAvatar/utils/io_utils.py ADDED
@@ -0,0 +1,256 @@
1
+ import subprocess
2
+ import torch, os
3
+ from safetensors import safe_open
4
+ from OmniAvatar.utils.args_config import args
5
+ from contextlib import contextmanager
6
+
7
+ import re
8
+ import tempfile
9
+ import numpy as np
10
+ import imageio
11
+ from glob import glob
12
+ import soundfile as sf
13
+ from einops import rearrange
14
+ import hashlib
15
+
16
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
17
+
18
+ @contextmanager
19
+ def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
20
+
21
+ old_register_parameter = torch.nn.Module.register_parameter
22
+ if include_buffers:
23
+ old_register_buffer = torch.nn.Module.register_buffer
24
+
25
+ def register_empty_parameter(module, name, param):
26
+ old_register_parameter(module, name, param)
27
+ if param is not None:
28
+ param_cls = type(module._parameters[name])
29
+ kwargs = module._parameters[name].__dict__
30
+ kwargs["requires_grad"] = param.requires_grad
31
+ module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
32
+
33
+ def register_empty_buffer(module, name, buffer, persistent=True):
34
+ old_register_buffer(module, name, buffer, persistent=persistent)
35
+ if buffer is not None:
36
+ module._buffers[name] = module._buffers[name].to(device)
37
+
38
+ def patch_tensor_constructor(fn):
39
+ def wrapper(*args, **kwargs):
40
+ kwargs["device"] = device
41
+ return fn(*args, **kwargs)
42
+
43
+ return wrapper
44
+
45
+ if include_buffers:
46
+ tensor_constructors_to_patch = {
47
+ torch_function_name: getattr(torch, torch_function_name)
48
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
49
+ }
50
+ else:
51
+ tensor_constructors_to_patch = {}
52
+
53
+ try:
54
+ torch.nn.Module.register_parameter = register_empty_parameter
55
+ if include_buffers:
56
+ torch.nn.Module.register_buffer = register_empty_buffer
57
+ for torch_function_name in tensor_constructors_to_patch.keys():
58
+ setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
59
+ yield
60
+ finally:
61
+ torch.nn.Module.register_parameter = old_register_parameter
62
+ if include_buffers:
63
+ torch.nn.Module.register_buffer = old_register_buffer
64
+ for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
65
+ setattr(torch, torch_function_name, old_torch_function)
66
+
67
+ def load_state_dict_from_folder(file_path, torch_dtype=None):
68
+ state_dict = {}
69
+ for file_name in os.listdir(file_path):
70
+ if "." in file_name and file_name.split(".")[-1] in [
71
+ "safetensors", "bin", "ckpt", "pth", "pt"
72
+ ]:
73
+ state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
74
+ return state_dict
75
+
76
+
77
+ def load_state_dict(file_path, torch_dtype=None):
78
+ if file_path.endswith(".safetensors"):
79
+ return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
80
+ else:
81
+ return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
82
+
83
+
84
+ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
85
+ state_dict = {}
86
+ with safe_open(file_path, framework="pt", device="cpu") as f:
87
+ for k in f.keys():
88
+ state_dict[k] = f.get_tensor(k)
89
+ if torch_dtype is not None:
90
+ state_dict[k] = state_dict[k].to(torch_dtype)
91
+ return state_dict
92
+
93
+
94
+ def load_state_dict_from_bin(file_path, torch_dtype=None):
95
+ state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
96
+ if torch_dtype is not None:
97
+ for i in state_dict:
98
+ if isinstance(state_dict[i], torch.Tensor):
99
+ state_dict[i] = state_dict[i].to(torch_dtype)
100
+ return state_dict
101
+
102
+ def smart_load_weights(model, ckpt_state_dict):
103
+ model_state_dict = model.state_dict()
104
+ new_state_dict = {}
105
+
106
+ for name, param in model_state_dict.items():
107
+ if name in ckpt_state_dict:
108
+ ckpt_param = ckpt_state_dict[name]
109
+ if param.shape == ckpt_param.shape:
110
+ new_state_dict[name] = ckpt_param
111
+ else:
112
+ # 自动修剪维度以匹配
113
+ if all(p >= c for p, c in zip(param.shape, ckpt_param.shape)):
114
+ print(f"[Truncate] {name}: ckpt {ckpt_param.shape} -> model {param.shape}")
115
+ # 创建新张量,拷贝旧数据
116
+ new_param = param.clone()
117
+ slices = tuple(slice(0, s) for s in ckpt_param.shape)
118
+ new_param[slices] = ckpt_param
119
+ new_state_dict[name] = new_param
120
+ else:
121
+ print(f"[Skip] {name}: ckpt {ckpt_param.shape} is larger than model {param.shape}")
122
+
123
+ # 更新 state_dict,只更新那些匹配的
124
+ missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, assign=True, strict=False)
125
+ return model, missing_keys, unexpected_keys
126
+
127
+ def save_wav(audio, audio_path):
128
+ if isinstance(audio, torch.Tensor):
129
+ audio = audio.float().detach().cpu().numpy()
130
+
131
+ if audio.ndim == 1:
132
+ audio = np.expand_dims(audio, axis=0) # (1, samples)
133
+
134
+ sf.write(audio_path, audio.T, 16000)
135
+
136
+ return True
137
+
138
+ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: float = 5,prompt=None, prompt_path=None, audio=None, audio_path=None, prefix=None):
139
+ os.makedirs(save_path, exist_ok=True)
140
+ out_videos = []
141
+
142
+ with tempfile.TemporaryDirectory() as tmp_path:
143
+
144
+ print(f'video batch shape:{video_batch.shape}')
145
+
146
+ for i, vid in enumerate(video_batch):
147
+ gif_frames = []
148
+
149
+ for frame in vid:
150
+ ft = frame.detach().cpu().clone()
151
+ ft = rearrange(ft, "c h w -> h w c")
152
+ arr = (255.0 * ft).numpy().astype(np.uint8)
153
+ gif_frames.append(arr)
154
+
155
+ if prefix is not None:
156
+ now_save_path = os.path.join(save_path, f"{prefix}_{i:03d}.mp4")
157
+ tmp_save_path = os.path.join(tmp_path, f"{prefix}_{i:03d}.mp4")
158
+ else:
159
+ now_save_path = os.path.join(save_path, f"{i:03d}.mp4")
160
+ tmp_save_path = os.path.join(tmp_path, f"{i:03d}.mp4")
161
+ with imageio.get_writer(tmp_save_path, fps=fps) as writer:
162
+ for frame in gif_frames:
163
+ writer.append_data(frame)
164
+ subprocess.run([f"cp {tmp_save_path} {now_save_path}"], check=True, shell=True)
165
+ print(f'save res video to : {now_save_path}')
166
+ final_video_path = now_save_path
167
+
168
+ if audio is not None or audio_path is not None:
169
+ if audio is not None:
170
+ audio_path = os.path.join(tmp_path, f"{i:06d}.mp3")
171
+ save_wav(audio[i], audio_path)
172
+ # cmd = f'/usr/bin/ffmpeg -i {tmp_save_path} -i {audio_path} -v quiet -c:v copy -c:a libmp3lame -strict experimental {tmp_save_path[:-4]}_wav.mp4 -y'
173
+ cmd = f'/usr/bin/ffmpeg -i {tmp_save_path} -i {audio_path} -v quiet -map 0:v:0 -map 1:a:0 -c:v copy -c:a aac {tmp_save_path[:-4]}_wav.mp4 -y'
174
+ subprocess.check_call(cmd, stdout=None, stdin=subprocess.PIPE, shell=True)
175
+ final_video_path = f"{now_save_path[:-4]}_wav.mp4"
176
+ subprocess.run([f"cp {tmp_save_path[:-4]}_wav.mp4 {final_video_path}"], check=True, shell=True)
177
+ os.remove(now_save_path)
178
+ if prompt is not None and prompt_path is not None:
179
+ with open(prompt_path, "w") as f:
180
+ f.write(prompt)
181
+ out_videos.append(final_video_path)
182
+
183
+ return out_videos
184
+
185
+ def is_zero_stage_3(trainer):
186
+ strategy = getattr(trainer, "strategy", None)
187
+ if strategy and hasattr(strategy, "model"):
188
+ ds_engine = strategy.model
189
+ stage = ds_engine.config.get("zero_optimization", {}).get("stage", 0)
190
+ return stage == 3
191
+ return False
192
+
193
+ def hash_state_dict_keys(state_dict, with_shape=True):
194
+ keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
195
+ keys_str = keys_str.encode(encoding="UTF-8")
196
+ return hashlib.md5(keys_str).hexdigest()
197
+
198
+ def split_state_dict_with_prefix(state_dict):
199
+ keys = sorted([key for key in state_dict if isinstance(key, str)])
200
+ prefix_dict = {}
201
+ for key in keys:
202
+ prefix = key if "." not in key else key.split(".")[0]
203
+ if prefix not in prefix_dict:
204
+ prefix_dict[prefix] = []
205
+ prefix_dict[prefix].append(key)
206
+ state_dicts = []
207
+ for prefix, keys in prefix_dict.items():
208
+ sub_state_dict = {key: state_dict[key] for key in keys}
209
+ state_dicts.append(sub_state_dict)
210
+ return state_dicts
211
+
231
+ def search_for_files(folder, extensions):
232
+ files = []
233
+ if os.path.isdir(folder):
234
+ for file in sorted(os.listdir(folder)):
235
+ files += search_for_files(os.path.join(folder, file), extensions)
236
+ elif os.path.isfile(folder):
237
+ for extension in extensions:
238
+ if folder.endswith(extension):
239
+ files.append(folder)
240
+ break
241
+ return files
242
+
243
+ def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
244
+ keys = []
245
+ for key, value in state_dict.items():
246
+ if isinstance(key, str):
247
+ if isinstance(value, torch.Tensor):
248
+ if with_shape:
249
+ shape = "_".join(map(str, list(value.shape)))
250
+ keys.append(key + ":" + shape)
251
+ keys.append(key)
252
+ elif isinstance(value, dict):
253
+ keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
254
+ keys.sort()
255
+ keys_str = ",".join(keys)
256
+ return keys_str
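
hash_state_dict_keys fingerprints a checkpoint by its key names (and optionally their shapes), not by the weight values, which is how the loading code can recognize a file's architecture. A small illustration with dummy tensors; it assumes the repository's requirements are installed, since io_utils imports its media dependencies at module level.

import torch
from OmniAvatar.utils.io_utils import hash_state_dict_keys

sd_a = {"blocks.0.attn.q.weight": torch.zeros(8, 8), "blocks.0.attn.k.weight": torch.zeros(8, 8)}
sd_b = {k: torch.randn_like(v) for k, v in sd_a.items()}   # same keys/shapes, different values

print(hash_state_dict_keys(sd_a) == hash_state_dict_keys(sd_b))   # True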
OmniAvatar/vram_management/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .layers import *
OmniAvatar/vram_management/layers.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch, copy
2
+ from ..utils.io_utils import init_weights_on_device
3
+
4
+
5
+ def cast_to(weight, dtype, device):
6
+ r = torch.empty_like(weight, dtype=dtype, device=device)
7
+ r.copy_(weight)
8
+ return r
9
+
10
+
11
+ class AutoWrappedModule(torch.nn.Module):
12
+ def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
13
+ super().__init__()
14
+ self.module = module.to(dtype=offload_dtype, device=offload_device)
15
+ self.offload_dtype = offload_dtype
16
+ self.offload_device = offload_device
17
+ self.onload_dtype = onload_dtype
18
+ self.onload_device = onload_device
19
+ self.computation_dtype = computation_dtype
20
+ self.computation_device = computation_device
21
+ self.state = 0
22
+
23
+ def offload(self):
24
+ if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
25
+ self.module.to(dtype=self.offload_dtype, device=self.offload_device)
26
+ self.state = 0
27
+
28
+ def onload(self):
29
+ if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
30
+ self.module.to(dtype=self.onload_dtype, device=self.onload_device)
31
+ self.state = 1
32
+
33
+ def forward(self, *args, **kwargs):
34
+ if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
35
+ module = self.module
36
+ else:
37
+ module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
38
+ return module(*args, **kwargs)
39
+
40
+
41
+ class AutoWrappedLinear(torch.nn.Linear):
42
+ def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
43
+ with init_weights_on_device(device=torch.device("meta")):
44
+ super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
45
+ self.weight = module.weight
46
+ self.bias = module.bias
47
+ self.offload_dtype = offload_dtype
48
+ self.offload_device = offload_device
49
+ self.onload_dtype = onload_dtype
50
+ self.onload_device = onload_device
51
+ self.computation_dtype = computation_dtype
52
+ self.computation_device = computation_device
53
+ self.state = 0
54
+
55
+ def offload(self):
56
+ if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
57
+ self.to(dtype=self.offload_dtype, device=self.offload_device)
58
+ self.state = 0
59
+
60
+ def onload(self):
61
+ if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
62
+ self.to(dtype=self.onload_dtype, device=self.onload_device)
63
+ self.state = 1
64
+
65
+ def forward(self, x, *args, **kwargs):
66
+ if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
67
+ weight, bias = self.weight, self.bias
68
+ else:
69
+ weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
70
+ bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
71
+ return torch.nn.functional.linear(x, weight, bias)
72
+
73
+
74
+ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
75
+ for name, module in model.named_children():
76
+ for source_module, target_module in module_map.items():
77
+ if isinstance(module, source_module):
78
+ num_param = sum(p.numel() for p in module.parameters())
79
+ if max_num_param is not None and total_num_param + num_param > max_num_param:
80
+ module_config_ = overflow_module_config
81
+ else:
82
+ module_config_ = module_config
83
+ module_ = target_module(module, **module_config_)
84
+ setattr(model, name, module_)
85
+ total_num_param += num_param
86
+ break
87
+ else:
88
+ total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
89
+ return total_num_param
90
+
91
+
92
+ def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
93
+ enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
94
+ model.vram_management_enabled = True
95
+
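
enable_vram_management walks the module tree and swaps every matching submodule for its wrapper; a wrapped Linear keeps its weights at the offload dtype/device and casts them to the computation dtype/device only inside forward. A toy, CPU-only sketch follows; the model and the config values are illustrative, not the pipeline's real settings.

import torch
from OmniAvatar.vram_management import enable_vram_management, AutoWrappedLinear

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
enable_vram_management(
    model,
    module_map={torch.nn.Linear: AutoWrappedLinear},
    module_config=dict(
        offload_dtype=torch.float32, offload_device="cpu",
        onload_dtype=torch.float32, onload_device="cpu",
        computation_dtype=torch.float32, computation_device="cpu",   # "cuda" on a GPU machine
    ),
)
# Linear layers are now AutoWrappedLinear instances; the forward pass still works.
print(type(model[0]).__name__, model(torch.randn(2, 16)).shape)   # AutoWrappedLinear torch.Size([2, 4])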
OmniAvatar/wan_video.py ADDED
@@ -0,0 +1,344 @@
1
+ import types
2
+ from .models.model_manager import ModelManager
3
+ from .models.wan_video_dit import WanModel
4
+ from .models.wan_video_text_encoder import WanTextEncoder
5
+ from .models.wan_video_vae import WanVideoVAE
6
+ from .schedulers.flow_match import FlowMatchScheduler
7
+ from .base import BasePipeline
8
+ from .prompters import WanPrompter
9
+ import torch, os
10
+ from einops import rearrange
11
+ import numpy as np
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+ from typing import Optional
15
+ from .vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
16
+ from .models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
17
+ from .models.wan_video_dit import RMSNorm
18
+ from .models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
19
+
20
+
21
+ class WanVideoPipeline(BasePipeline):
22
+
23
+ def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
24
+ super().__init__(device=device, torch_dtype=torch_dtype)
25
+ self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
26
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
27
+ self.text_encoder: WanTextEncoder = None
28
+ self.image_encoder = None
29
+ self.dit: WanModel = None
30
+ self.vae: WanVideoVAE = None
31
+ self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
32
+ self.height_division_factor = 16
33
+ self.width_division_factor = 16
34
+ self.use_unified_sequence_parallel = False
35
+ self.sp_size = 1
36
+
37
+
38
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
39
+ dtype = next(iter(self.text_encoder.parameters())).dtype
40
+ enable_vram_management(
41
+ self.text_encoder,
42
+ module_map = {
43
+ torch.nn.Linear: AutoWrappedLinear,
44
+ torch.nn.Embedding: AutoWrappedModule,
45
+ T5RelativeEmbedding: AutoWrappedModule,
46
+ T5LayerNorm: AutoWrappedModule,
47
+ },
48
+ module_config = dict(
49
+ offload_dtype=dtype,
50
+ offload_device="cpu",
51
+ onload_dtype=dtype,
52
+ onload_device="cpu",
53
+ computation_dtype=self.torch_dtype,
54
+ computation_device=self.device,
55
+ ),
56
+ )
57
+ dtype = next(iter(self.dit.parameters())).dtype
58
+ enable_vram_management(
59
+ self.dit,
60
+ module_map = {
61
+ torch.nn.Linear: AutoWrappedLinear,
62
+ torch.nn.Conv3d: AutoWrappedModule,
63
+ torch.nn.LayerNorm: AutoWrappedModule,
64
+ RMSNorm: AutoWrappedModule,
65
+ },
66
+ module_config = dict(
67
+ offload_dtype=dtype,
68
+ offload_device="cpu",
69
+ onload_dtype=dtype,
70
+ onload_device=self.device,
71
+ computation_dtype=self.torch_dtype,
72
+ computation_device=self.device,
73
+ ),
74
+ max_num_param=num_persistent_param_in_dit,
75
+ overflow_module_config = dict(
76
+ offload_dtype=dtype,
77
+ offload_device="cpu",
78
+ onload_dtype=dtype,
79
+ onload_device="cpu",
80
+ computation_dtype=self.torch_dtype,
81
+ computation_device=self.device,
82
+ ),
83
+ )
84
+ dtype = next(iter(self.vae.parameters())).dtype
85
+ enable_vram_management(
86
+ self.vae,
87
+ module_map = {
88
+ torch.nn.Linear: AutoWrappedLinear,
89
+ torch.nn.Conv2d: AutoWrappedModule,
90
+ RMS_norm: AutoWrappedModule,
91
+ CausalConv3d: AutoWrappedModule,
92
+ Upsample: AutoWrappedModule,
93
+ torch.nn.SiLU: AutoWrappedModule,
94
+ torch.nn.Dropout: AutoWrappedModule,
95
+ },
96
+ module_config = dict(
97
+ offload_dtype=dtype,
98
+ offload_device="cpu",
99
+ onload_dtype=dtype,
100
+ onload_device=self.device,
101
+ computation_dtype=self.torch_dtype,
102
+ computation_device=self.device,
103
+ ),
104
+ )
105
+ if self.image_encoder is not None:
106
+ dtype = next(iter(self.image_encoder.parameters())).dtype
107
+ enable_vram_management(
108
+ self.image_encoder,
109
+ module_map = {
110
+ torch.nn.Linear: AutoWrappedLinear,
111
+ torch.nn.Conv2d: AutoWrappedModule,
112
+ torch.nn.LayerNorm: AutoWrappedModule,
113
+ },
114
+ module_config = dict(
115
+ offload_dtype=dtype,
116
+ offload_device="cpu",
117
+ onload_dtype=dtype,
118
+ onload_device="cpu",
119
+ computation_dtype=dtype,
120
+ computation_device=self.device,
121
+ ),
122
+ )
123
+ self.enable_cpu_offload()
124
+
125
+
126
+ def fetch_models(self, model_manager: ModelManager):
127
+ text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
128
+ if text_encoder_model_and_path is not None:
129
+ self.text_encoder, tokenizer_path = text_encoder_model_and_path
130
+ self.prompter.fetch_models(self.text_encoder)
131
+ self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
132
+ self.dit = model_manager.fetch_model("wan_video_dit")
133
+ self.vae = model_manager.fetch_model("wan_video_vae")
134
+ self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
135
+
136
+
137
+ @staticmethod
138
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False, infer=False):
139
+ if device is None: device = model_manager.device
140
+ if torch_dtype is None: torch_dtype = model_manager.torch_dtype
141
+ pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
142
+ pipe.fetch_models(model_manager)
143
+ if use_usp:
144
+ from xfuser.core.distributed import get_sequence_parallel_world_size, get_sp_group
145
+ from OmniAvatar.distributed.xdit_context_parallel import usp_attn_forward
146
+ for block in pipe.dit.blocks:
147
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
148
+ pipe.sp_size = get_sequence_parallel_world_size()
149
+ pipe.use_unified_sequence_parallel = True
150
+ pipe.sp_group = get_sp_group()
151
+ return pipe
152
+
153
+
154
+ def denoising_model(self):
155
+ return self.dit
156
+
157
+
158
+ def encode_prompt(self, prompt, positive=True):
159
+ prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
160
+ return {"context": prompt_emb}
161
+
162
+
163
+ def encode_image(self, image, num_frames, height, width):
164
+ image = self.preprocess_image(image.resize((width, height))).to(self.device, dtype=self.torch_dtype)
165
+ clip_context = self.image_encoder.encode_image([image])
166
+ clip_context = clip_context.to(dtype=self.torch_dtype)
167
+ msk = torch.ones(1, num_frames, height//8, width//8, device=self.device, dtype=self.torch_dtype)
168
+ msk[:, 1:] = 0
169
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
170
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
171
+ msk = msk.transpose(1, 2)[0]
172
+
173
+ vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device, dtype=self.torch_dtype)], dim=1)
174
+ y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
175
+ y = torch.concat([msk, y])
176
+ y = y.unsqueeze(0)
177
+ clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
178
+ y = y.to(dtype=self.torch_dtype, device=self.device)
179
+ return {"clip_feature": clip_context, "y": y}
180
+
181
+
182
+ def tensor2video(self, frames):
183
+ frames = rearrange(frames, "C T H W -> T H W C")
184
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
185
+ frames = [Image.fromarray(frame) for frame in frames]
186
+ return frames
187
+
188
+
189
+ def prepare_extra_input(self, latents=None):
190
+ return {}
191
+
192
+
193
+ def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
194
+ latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
195
+ return latents
196
+
197
+
198
+ def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
199
+ frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
200
+ return frames
201
+
202
+
203
+ def prepare_unified_sequence_parallel(self):
204
+ return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
205
+
206
+
207
+ @torch.no_grad()
208
+ def log_video(
209
+ self,
210
+ lat,
211
+ prompt,
212
+ fixed_frame=0, # lat frames
213
+ image_emb={},
214
+ audio_emb={},
215
+ negative_prompt="",
216
+ cfg_scale=5.0,
217
+ audio_cfg_scale=5.0,
218
+ num_inference_steps=50,
219
+ denoising_strength=1.0,
220
+ sigma_shift=5.0,
221
+ tiled=True,
222
+ tile_size=(30, 52),
223
+ tile_stride=(15, 26),
224
+ tea_cache_l1_thresh=None,
225
+ tea_cache_model_id="",
226
+ progress_bar_cmd=None,
227
+ return_latent=False,
228
+ ):
229
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
230
+ # Scheduler
231
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
232
+
233
+ lat = lat.to(dtype=self.torch_dtype)
234
+ latents = lat.clone()
235
+ latents = torch.randn_like(latents, dtype=self.torch_dtype)
236
+
237
+ # Encode prompts
238
+ self.load_models_to_device(["text_encoder"])
239
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
240
+ if cfg_scale != 1.0:
241
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
242
+
243
+ # Extra input
244
+ extra_input = self.prepare_extra_input(latents)
245
+
246
+ # TeaCache
247
+ tea_cache_posi = {"tea_cache": None}
248
+ tea_cache_nega = {"tea_cache": None}
249
+
250
+ # Denoise
251
+ self.load_models_to_device(["dit"])
252
+ for progress_id, timestep in enumerate(tqdm(self.scheduler.timesteps) if progress_bar_cmd is None else self.scheduler.timesteps ):
253
+ if fixed_frame > 0: # new
254
+ latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
255
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
256
+
257
+ # Inference
258
+ noise_pred_posi = self.dit(x=latents, timestep=timestep, **prompt_emb_posi, **image_emb, **audio_emb, **tea_cache_posi, **extra_input)
259
+
260
+ if cfg_scale != 1.0:
261
+ audio_emb_uc = {}
262
+ for key in audio_emb.keys():
263
+ audio_emb_uc[key] = torch.zeros_like(audio_emb[key], dtype=self.torch_dtype)
264
+ if audio_cfg_scale == cfg_scale:
265
+ noise_pred_nega = self.dit(x=latents, timestep=timestep, **prompt_emb_nega, **image_emb, **audio_emb_uc, **tea_cache_nega, **extra_input)
266
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
267
+ else:
268
+ tea_cache_nega_audio = {"tea_cache": None}
269
+ audio_noise_pred_nega = self.dit(x=latents, timestep=timestep, **prompt_emb_posi, **image_emb, **audio_emb_uc, **tea_cache_nega_audio, **extra_input)
270
+ text_noise_pred_nega = self.dit(x=latents, timestep=timestep, **prompt_emb_nega, **image_emb, **audio_emb_uc, **tea_cache_nega, **extra_input)
271
+ noise_pred = text_noise_pred_nega + cfg_scale * (audio_noise_pred_nega - text_noise_pred_nega) + audio_cfg_scale * (noise_pred_posi - audio_noise_pred_nega)
272
+ else:
273
+ noise_pred = noise_pred_posi
274
+ # Scheduler
275
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
276
+
277
+ if progress_bar_cmd is not None:
278
+ progress_bar_cmd.update(1)
279
+
280
+
281
+ if fixed_frame > 0: # new
282
+ latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
283
+ # Decode
284
+ self.load_models_to_device(['vae'])
285
+ frames = self.decode_video(latents, **tiler_kwargs)
286
+ recons = self.decode_video(lat, **tiler_kwargs)
287
+ self.load_models_to_device([])
288
+ frames = (frames.permute(0, 2, 1, 3, 4).float() + 1.0) / 2.0
289
+ recons = (recons.permute(0, 2, 1, 3, 4).float() + 1.0) / 2.0
290
+ if return_latent:
291
+ return frames, recons, latents
292
+ return frames, recons
293
+
294
+
295
+ class TeaCache:
296
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
297
+ self.num_inference_steps = num_inference_steps
298
+ self.step = 0
299
+ self.accumulated_rel_l1_distance = 0
300
+ self.previous_modulated_input = None
301
+ self.rel_l1_thresh = rel_l1_thresh
302
+ self.previous_residual = None
303
+ self.previous_hidden_states = None
304
+
305
+ self.coefficients_dict = {
306
+ "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
307
+ "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
308
+ "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
309
+ "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
310
+ }
311
+ if model_id not in self.coefficients_dict:
312
+ supported_model_ids = ", ".join([i for i in self.coefficients_dict])
313
+ raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
314
+ self.coefficients = self.coefficients_dict[model_id]
315
+
316
+ def check(self, dit: WanModel, x, t_mod):
317
+ modulated_inp = t_mod.clone()
318
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
319
+ should_calc = True
320
+ self.accumulated_rel_l1_distance = 0
321
+ else:
322
+ coefficients = self.coefficients
323
+ rescale_func = np.poly1d(coefficients)
324
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
325
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
326
+ should_calc = False
327
+ else:
328
+ should_calc = True
329
+ self.accumulated_rel_l1_distance = 0
330
+ self.previous_modulated_input = modulated_inp
331
+ self.step += 1
332
+ if self.step == self.num_inference_steps:
333
+ self.step = 0
334
+ if should_calc:
335
+ self.previous_hidden_states = x.clone()
336
+ return not should_calc
337
+
338
+ def store(self, hidden_states):
339
+ self.previous_residual = hidden_states - self.previous_hidden_states
340
+ self.previous_hidden_states = None
341
+
342
+ def update(self, hidden_states):
343
+ hidden_states = hidden_states + self.previous_residual
344
+ return hidden_states
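For context on how the TeaCache class above is driven during sampling: check() decides per denoising step whether the full transformer pass can be skipped, store() caches the residual of a completed pass, and update() replays that residual on later steps. A minimal sketch of the control flow, assuming a run_blocks callable that stands in for the DiT's transformer blocks (the real wiring is inside the model that receives the tea_cache kwarg, e.g. OmniAvatar/models/wan_video_dit.py):

def dit_forward_with_tea_cache(x, t_mod, tea_cache, run_blocks):
    # check() snapshots x and returns True when the accumulated relative L1
    # change of t_mod is still below rel_l1_thresh, i.e. the cached residual
    # can be reused (the dit argument is not used inside check()).
    if tea_cache is not None and tea_cache.check(None, x, t_mod):
        return tea_cache.update(x)      # x + previous_residual, blocks skipped
    x_out = run_blocks(x, t_mod)        # full transformer pass
    if tea_cache is not None:
        tea_cache.store(x_out)          # residual = x_out - snapshot from check()
    return x_out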
README.md CHANGED
@@ -1,12 +1,13 @@
1
- ---
2
- title: OmniAvatar Clay Fast
3
- emoji: 📚
4
- colorFrom: green
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.44.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
+ ---
2
+ title: OmniAvatar-Clay-Fast
3
+ emoji: 🐨
4
+ colorFrom: yellow
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 5.36.2
8
+ app_file: app.py
9
+ pinned: false
10
+ short_description: Generate a claymation-style avatar to host your podcast
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,942 @@
1
+ import spaces
2
+ import subprocess
3
+ import gradio as gr
4
+
5
+ import os, sys
6
+ from glob import glob
7
+ from datetime import datetime
8
+ import math
9
+ import random
10
+ import librosa
11
+ import numpy as np
12
+ import uuid
13
+ import shutil
14
+ from tqdm import tqdm
15
+
16
+ import importlib, site, sys
17
+ from huggingface_hub import hf_hub_download, snapshot_download
18
+
19
+ # Re-discover all .pth/.egg-link files
20
+ for sitedir in site.getsitepackages():
21
+ site.addsitedir(sitedir)
22
+
23
+ # Clear caches so importlib will pick up new modules
24
+ importlib.invalidate_caches()
25
+
26
+ def sh(cmd): subprocess.check_call(cmd, shell=True)
27
+
28
+ flash_attention_installed = False
29
+
30
+ try:
31
+ print("Attempting to download and install FlashAttention wheel...")
32
+ flash_attention_wheel = hf_hub_download(
33
+ repo_id="alexnasa/flash-attn-3",
34
+ repo_type="model",
35
+ filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
36
+ )
37
+
38
+ sh(f"pip install {flash_attention_wheel}")
39
+
40
+ # tell Python to re-scan site-packages now that the new wheel is installed
41
+ import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
42
+
43
+ flash_attention_installed = True
44
+ print("FlashAttention installed successfully.")
45
+
46
+ except Exception as e:
47
+ print(f"⚠️ Could not install FlashAttention: {e}")
48
+ print("Continuing without FlashAttention...")
49
+
50
+ import torch
51
+ print(f"Torch version: {torch.__version__}")
52
+ # print(f"FlashAttention available: {flash_attention_installed}")
53
+
54
+
55
+
56
+ import torch.nn as nn
57
+ from tqdm import tqdm
58
+ from functools import partial
59
+ from omegaconf import OmegaConf
60
+ from argparse import Namespace
61
+ from gradio_extendedimage import extendedimage
62
+
63
+ import torchaudio
64
+
65
+ # load the one true config you dumped
66
+ _args_cfg = OmegaConf.load("args_config.yaml")
67
+ args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
68
+
69
+ from OmniAvatar.utils.args_config import set_global_args
70
+
71
+ set_global_args(args)
72
+ # args = parse_args()
73
+
74
+ from OmniAvatar.utils.io_utils import load_state_dict
75
+ from peft import LoraConfig, inject_adapter_in_model
76
+ from OmniAvatar.models.model_manager import ModelManager
77
+ from OmniAvatar.schedulers.flow_match import FlowMatchScheduler
78
+ from OmniAvatar.wan_video import WanVideoPipeline
79
+ from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
80
+ import torchvision.transforms as TT
81
+ from transformers import Wav2Vec2FeatureExtractor
82
+ import torchvision.transforms as transforms
83
+ import torch.nn.functional as F
84
+ from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
85
+
86
+ from diffusers import FluxKontextPipeline
87
+ from diffusers.utils import load_image
88
+
89
+ from PIL import Image
90
+
91
+
92
+ os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/processed_results"
93
+
94
+
95
+ flux_pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
96
+ flux_pipe.load_lora_weights("alexnasa/Claymation-Kontext-Dev-Lora")
97
+ flux_pipe.to("cuda")
98
+ flux_inference = 10
99
+
100
+ def tensor_to_pil(tensor):
101
+ """
102
+ Args:
103
+ tensor: torch.Tensor with shape like
104
+ (1, C, H, W), (1, C, 1, H, W), (C, H, W), etc.
105
+ values in [-1, 1], on any device.
106
+ Returns:
107
+ A PIL.Image in RGB mode.
108
+ """
109
+ # 1) Remove batch dim if it exists
110
+ if tensor.dim() > 3 and tensor.shape[0] == 1:
111
+ tensor = tensor[0]
112
+
113
+ # 2) Squeeze out any other singleton dims (e.g. that extra frame axis)
114
+ tensor = tensor.squeeze()
115
+
116
+ # Now we should have exactly 3 dims: (C, H, W)
117
+ if tensor.dim() != 3:
118
+ raise ValueError(f"Expected 3 dims after squeeze, got {tensor.dim()}")
119
+
120
+ # 3) Move to CPU float32
121
+ tensor = tensor.cpu().float()
122
+
123
+ # 4) Undo normalization from [-1,1] -> [0,1]
124
+ tensor = (tensor + 1.0) / 2.0
125
+
126
+ # 5) Clamp to [0,1]
127
+ tensor = torch.clamp(tensor, 0.0, 1.0)
128
+
129
+ # 6) To NumPy H×W×C in [0,255]
130
+ np_img = (tensor.permute(1, 2, 0).numpy() * 255.0).round().astype("uint8")
131
+
132
+ # 7) Build PIL Image
133
+ return Image.fromarray(np_img)
134
+
135
+
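A quick usage sketch for tensor_to_pil (illustrative only; the dummy tensor mimics the (B, C, T, H, W) frames handled elsewhere in this file, and preview.jpg is a made-up output path):

import torch
dummy = torch.rand(1, 3, 1, 400, 720) * 2.0 - 1.0   # (B, C, T, H, W), values in [-1, 1]
tensor_to_pil(dummy).save("preview.jpg")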
136
+ def set_seed(seed: int = 42):
137
+ random.seed(seed)
138
+ np.random.seed(seed)
139
+ torch.manual_seed(seed)
140
+ torch.cuda.manual_seed(seed) # seed the current GPU
141
+ torch.cuda.manual_seed_all(seed) # seed all GPUs
142
+
143
+ def read_from_file(p):
144
+ with open(p, "r") as fin:
145
+ for l in fin:
146
+ yield l.strip()
147
+
148
+ def match_size(image_size, h, w):
149
+ ratio_ = 9999
150
+ size_ = 9999
151
+ select_size = None
152
+ for image_s in image_size:
153
+ ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
154
+ size_tmp = abs(max(image_s) - max(w, h))
155
+ if ratio_tmp < ratio_:
156
+ ratio_ = ratio_tmp
157
+ size_ = size_tmp
158
+ select_size = image_s
159
+ if ratio_ == ratio_tmp:
160
+ if size_ == size_tmp:
161
+ select_size = image_s
162
+ return select_size
163
+
164
+ def resize_pad(image, ori_size, tgt_size):
165
+ h, w = ori_size
166
+ scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
167
+ scale_h = int(h * scale_ratio)
168
+ scale_w = int(w * scale_ratio)
169
+
170
+ image = transforms.Resize(size=[scale_h, scale_w])(image)
171
+
172
+ padding_h = tgt_size[0] - scale_h
173
+ padding_w = tgt_size[1] - scale_w
174
+ pad_top = padding_h // 2
175
+ pad_bottom = padding_h - pad_top
176
+ pad_left = padding_w // 2
177
+ pad_right = padding_w - pad_left
178
+
179
+ image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
180
+ return image
181
+
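A small sketch of how match_size and resize_pad work together (illustrative; the bucket list mirrors image_sizes_720 in configs/inference.yaml):

import torch
buckets = [[400, 720], [720, 720], [720, 400]]
h, w = 1920, 1080                                        # portrait input
tgt = match_size(buckets, h, w)                          # -> [720, 400], closest aspect ratio
out = resize_pad(torch.rand(1, 3, h, w), (h, w), tgt)    # resize to cover, then center pad/crop
print(tgt, out.shape)                                    # [720, 400] torch.Size([1, 3, 720, 400])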
182
+ class WanInferencePipeline(nn.Module):
183
+ def __init__(self, args):
184
+ super().__init__()
185
+ self.args = args
186
+ self.device = torch.device(f"cuda")
187
+ self.dtype = torch.bfloat16
188
+ self.pipe = self.load_model()
189
+ chained_transforms = []
190
+ chained_transforms.append(TT.ToTensor())
191
+ self.transform = TT.Compose(chained_transforms)
192
+
193
+ if self.args.use_audio:
194
+ from OmniAvatar.models.wav2vec import Wav2VecModel
195
+ self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
196
+ self.args.wav2vec_path
197
+ )
198
+ self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device, dtype=self.dtype)
199
+ self.audio_encoder.feature_extractor._freeze_parameters()
200
+
201
+
202
+ def load_model(self):
203
+ ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
204
+ assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
205
+ if self.args.train_architecture == 'lora':
206
+ self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
207
+ else:
208
+ resume_path = ckpt_path
209
+
210
+ self.step = 0
211
+
212
+ # Load models
213
+ model_manager = ModelManager(device="cuda", infer=True)
214
+
215
+ model_manager.load_models(
216
+ [
217
+ self.args.dit_path.split(","),
218
+ self.args.vae_path,
219
+ self.args.text_encoder_path
220
+ ],
221
+ torch_dtype=self.dtype,
222
+ device='cuda',
223
+ )
224
+
225
+ pipe = WanVideoPipeline.from_model_manager(model_manager,
226
+ torch_dtype=self.dtype,
227
+ device="cuda",
228
+ use_usp=False,
229
+ infer=True)
230
+
231
+ if self.args.train_architecture == "lora":
232
+ print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
233
+ self.add_lora_to_model(
234
+ pipe.denoising_model(),
235
+ lora_rank=self.args.lora_rank,
236
+ lora_alpha=self.args.lora_alpha,
237
+ lora_target_modules=self.args.lora_target_modules,
238
+ init_lora_weights=self.args.init_lora_weights,
239
+ pretrained_lora_path=pretrained_lora_path,
240
+ )
241
+ print(next(pipe.denoising_model().parameters()).device)
242
+ else:
243
+ missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
244
+ print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
245
+ pipe.requires_grad_(False)
246
+ pipe.eval()
247
+ # pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
248
+ return pipe
249
+
250
+ def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
251
+ # Add LoRA to the denoising model (DiT)
252
+
253
+ self.lora_alpha = lora_alpha
254
+ if init_lora_weights == "kaiming":
255
+ init_lora_weights = True
256
+
257
+ lora_config = LoraConfig(
258
+ r=lora_rank,
259
+ lora_alpha=lora_alpha,
260
+ init_lora_weights=init_lora_weights,
261
+ target_modules=lora_target_modules.split(","),
262
+ )
263
+ model = inject_adapter_in_model(lora_config, model)
264
+
265
+ # Load pretrained LoRA weights
266
+ if pretrained_lora_path is not None:
267
+ state_dict = load_state_dict(pretrained_lora_path, torch_dtype=self.dtype)
268
+ if state_dict_converter is not None:
269
+ state_dict = state_dict_converter(state_dict)
270
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
271
+ all_keys = [i for i, _ in model.named_parameters()]
272
+ num_updated_keys = len(all_keys) - len(missing_keys)
273
+ num_unexpected_keys = len(unexpected_keys)
274
+
275
+ print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
276
+
277
+ def get_times(self, prompt,
278
+ image_path=None,
279
+ audio_path=None,
280
+ seq_len=101, # not used when audio_path is provided
281
+ height=720,
282
+ width=720,
283
+ overlap_frame=None,
284
+ num_steps=None,
285
+ negative_prompt=None,
286
+ guidance_scale=None,
287
+ audio_scale=None):
288
+
289
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
290
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
291
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
292
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
293
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
294
+
295
+ if image_path is not None:
296
+ image = Image.open(image_path).convert("RGB")
297
+
298
+ image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
299
+
300
+ _, _, h, w = image.shape
301
+ select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
302
+ image = resize_pad(image, (h, w), select_size)
303
+ image = image * 2.0 - 1.0
304
+ image = image[:, :, None]
305
+
306
+ else:
307
+ image = None
308
+ select_size = [height, width]
309
+ num = self.args.max_tokens * 16 * 16 * 4
310
+ den = select_size[0] * select_size[1]
311
+ L0 = num // den
312
+ diff = (L0 - 1) % 4
313
+ L = L0 - diff
314
+ if L < 1:
315
+ L = 1
316
+ T = (L + 3) // 4
317
+
318
+
319
+ if self.args.random_prefix_frames:
320
+ fixed_frame = overlap_frame
321
+ assert fixed_frame % 4 == 1
322
+ else:
323
+ fixed_frame = 1
324
+ prefix_lat_frame = (3 + fixed_frame) // 4
325
+ first_fixed_frame = 1
326
+
327
+
328
+ audio, sr = librosa.load(audio_path, sr= self.args.sample_rate)
329
+
330
+ input_values = np.squeeze(
331
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
332
+ )
333
+ input_values = torch.from_numpy(input_values).float().to(dtype=self.dtype)
334
+ audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
335
+
336
+ if audio_len < L - first_fixed_frame:
337
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
338
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
339
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
340
+
341
+ seq_len = audio_len
342
+
343
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
344
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
345
+ times += 1
346
+
347
+ return times
348
+
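+ # Worked example for the padding/chunking arithmetic above (illustrative, using the
+ # defaults shipped in args_config.yaml: 720x400 gives L = 141, first_fixed_frame = 1,
+ # fixed_frame = overlap_frame = 13, fps = 24): 10 s of audio gives audio_len = 240,
+ # padded to 268 so that (268 - 140) % 128 == 0, and then
+ # times = (268 - 141 + 1) // 128 + 1 = 2 generation chunks.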
349
+ @torch.no_grad()
350
+ def forward(self, prompt,
351
+ image_path=None,
352
+ audio_path=None,
353
+ seq_len=101, # not used when audio_path is provided
354
+ height=720,
355
+ width=720,
356
+ overlap_frame=None,
357
+ num_steps=None,
358
+ negative_prompt=None,
359
+ guidance_scale=None,
360
+ audio_scale=None):
361
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
362
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
363
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
364
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
365
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
366
+
367
+ if image_path is not None:
368
+ image = Image.open(image_path).convert("RGB")
369
+
370
+ image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
371
+
372
+ _, _, h, w = image.shape
373
+ select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
374
+ image = resize_pad(image, (h, w), select_size)
375
+ image = image * 2.0 - 1.0
376
+ image = image[:, :, None]
377
+
378
+ else:
379
+ image = None
380
+ select_size = [height, width]
381
+ # L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
382
+ # L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
383
+ # T = (L + 3) // 4 # latent frames
384
+
385
+ # step 1: numerator and denominator as ints
386
+ num = args.max_tokens * 16 * 16 * 4
387
+ den = select_size[0] * select_size[1]
388
+
389
+ # step 2: integer division
390
+ L0 = num // den # exact floor division, no float in sight
391
+
392
+ # step 3: make it ≡ 1 mod 4
393
+ # if L0 % 4 == 1, keep L0;
394
+ # otherwise subtract the difference so that (L0 - diff) % 4 == 1,
395
+ # but ensure the result stays positive.
396
+ diff = (L0 - 1) % 4
397
+ L = L0 - diff
398
+ if L < 1:
399
+ L = 1 # or whatever your minimal frame count is
400
+
401
+ # step 4: latent frames
402
+ T = (L + 3) // 4
403
+
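+ # Worked example (illustrative): with max_tokens = 40000 and select_size = [720, 400]
+ # (the defaults in args_config.yaml), num = 40,960,000 and den = 288,000, so L0 = 142,
+ # diff = (142 - 1) % 4 = 1, L = 141 (141 % 4 == 1) and T = (141 + 3) // 4 = 36 latent frames.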
404
+
405
+ if self.args.i2v:
406
+ if self.args.random_prefix_frames:
407
+ fixed_frame = overlap_frame
408
+ assert fixed_frame % 4 == 1
409
+ else:
410
+ fixed_frame = 1
411
+ prefix_lat_frame = (3 + fixed_frame) // 4
412
+ first_fixed_frame = 1
413
+ else:
414
+ fixed_frame = 0
415
+ prefix_lat_frame = 0
416
+ first_fixed_frame = 0
417
+
418
+
419
+ if audio_path is not None and self.args.use_audio:
420
+ audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
421
+ input_values = np.squeeze(
422
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
423
+ )
424
+ input_values = torch.from_numpy(input_values).float().to(device=self.device, dtype=self.dtype)
425
+ ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
426
+ input_values = input_values.unsqueeze(0)
427
+ # padding audio
428
+ if audio_len < L - first_fixed_frame:
429
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
430
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
431
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
432
+ input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
433
+ with torch.no_grad():
434
+ hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
435
+ audio_embeddings = hidden_states.last_hidden_state
436
+ for mid_hidden_states in hidden_states.hidden_states:
437
+ audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
438
+ seq_len = audio_len
439
+ audio_embeddings = audio_embeddings.squeeze(0)
440
+ audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
441
+ else:
442
+ audio_embeddings = None
443
+
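+ # At this point audio_embeddings holds one feature row per video frame: the wav2vec2
+ # last_hidden_state concatenated with every returned intermediate hidden state along
+ # the feature dimension (roughly 768 * (1 + num_hidden_states) features per frame for
+ # wav2vec2-base-960h), and audio_prefix is a zero placeholder for the first chunk's
+ # fixed prefix frame.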
444
+ # loop
445
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
446
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
447
+ times += 1
448
+ video = []
449
+ image_emb = {}
450
+ img_lat = None
451
+ if self.args.i2v:
452
+ self.pipe.load_models_to_device(['vae'])
453
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
454
+
455
+ msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:,:1], dtype=self.dtype)
456
+ image_cat = img_lat.repeat(1, 1, T, 1, 1)
457
+ msk[:, :, 1:] = 1
458
+ image_emb["y"] = torch.cat([image_cat, msk], dim=1)
459
+
460
+ total_iterations = times * num_steps
461
+
462
+ with tqdm(total=total_iterations) as pbar:
463
+ for t in range(times):
464
+ print(f"[{t+1}/{times}]")
465
+ audio_emb = {}
466
+ if t == 0:
467
+ overlap = first_fixed_frame
468
+ else:
469
+ overlap = fixed_frame
470
+ image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # 第一次推理是mask只有1,往后都是mask overlap
471
+ prefix_overlap = (3 + overlap) // 4
472
+ if audio_embeddings is not None:
473
+ if t == 0:
474
+ audio_tensor = audio_embeddings[
475
+ :min(L - overlap, audio_embeddings.shape[0])
476
+ ]
477
+ else:
478
+ audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
479
+ audio_tensor = audio_embeddings[
480
+ audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
481
+ ]
482
+
483
+ audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
484
+ audio_prefix = audio_tensor[-fixed_frame:]
485
+ audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
486
+ audio_emb["audio_emb"] = audio_tensor
487
+ else:
488
+ audio_prefix = None
489
+ if image is not None and img_lat is None:
490
+ self.pipe.load_models_to_device(['vae'])
491
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
492
+ assert img_lat.shape[2] == prefix_overlap
493
+ img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
494
+ frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
495
+ negative_prompt, num_inference_steps=num_steps,
496
+ cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
497
+ return_latent=True,
498
+ tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B", progress_bar_cmd=pbar)
499
+
500
+ torch.cuda.empty_cache()
501
+ img_lat = None
502
+ image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
503
+
504
+ if t == 0:
505
+ video.append(frames)
506
+ else:
507
+ video.append(frames[:, overlap:])
508
+ video = torch.cat(video, dim=1)
509
+ video = video[:, :ori_audio_len + 1]
510
+
511
+ return video
512
+
513
+
514
+ snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./pretrained_models/Wan2.1-T2V-1.3B")
515
+ snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
516
+ snapshot_download(repo_id="OmniAvatar/OmniAvatar-1.3B", local_dir="./pretrained_models/OmniAvatar-1.3B")
517
+
518
+ import tempfile
519
+
520
+
521
+ set_seed(args.seed)
522
+ seq_len = args.seq_len
523
+ inferpipe = WanInferencePipeline(args)
524
+
525
+
526
+ ADAPTIVE_PROMPT_TEMPLATES = [
527
+ "A claymation video of a person speaking and moving their head accordingly but without moving their hands.",
528
+ "A claymation video of a person speaking and sometimes looking directly to the camera and moving their eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera but with subtle hands movement that complements their speech.",
529
+ "A claymation video of a person speaking and sometimes looking directly to the camera and moving their eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on their movements with dynamic and rhythmic and subtle hand gestures that complement their speech and don't disrupt things if they are holding something with their hands. Their hands are clearly visible, independent, and unobstructed. Their facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence."
530
+ ]
531
+
532
+ def slider_value_change(image_path, audio_path, text, num_steps, session_state):
533
+ return update_generate_button(image_path, audio_path, text, num_steps, session_state), text
534
+
535
+
536
+ def update_generate_button(image_path, audio_path, text, num_steps, session_state):
537
+
538
+ if image_path is None or audio_path is None:
539
+ return gr.update(value="⌚ Zero GPU Required: --")
540
+
541
+ duration_s = get_duration(image_path, audio_path, text, num_steps, session_state, None)
542
+ duration_m = duration_s / 60
543
+
544
+ return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)")
545
+
546
+ def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
547
+
548
+ if image_path is None:
549
+ gr.Info("Step1: Please Provide an Image or Choose from Image Samples")
550
+ print("Step1: Please Provide an Image or Choose from Image Samples")
551
+
552
+ return 0
553
+
554
+ if audio_path is None:
555
+ gr.Info("Step2: Please Provide an Audio or Choose from Audio Samples")
556
+ print("Step2: Please Provide an Audio or Choose from Audio Samples")
557
+
558
+ return 0
559
+
560
+
561
+ audio_chunks = inferpipe.get_times(
562
+ prompt=text,
563
+ image_path=image_path,
564
+ audio_path=audio_path,
565
+ seq_len=args.seq_len,
566
+ num_steps=num_steps
567
+ )
568
+
569
+ if session_id is None:
570
+ session_id = uuid.uuid4().hex
571
+
572
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
573
+
574
+ dirpath = os.path.dirname(image_path)
575
+ basename = os.path.basename(image_path)
576
+ name, ext = os.path.splitext(basename)
577
+
578
+ new_basename = f"clay_{name}{ext}"
579
+ clay_image_path = os.path.join(dirpath, new_basename)
580
+
581
+ if os.path.exists(clay_image_path):
582
+ claymation = 0
583
+ else:
584
+ claymation = flux_inference * 2
585
+
586
+ warmup_s = 15
587
+ last_step_s = 20
588
+ duration_s = (4 * (num_steps - 1) + last_step_s)
589
+
590
+ if audio_chunks > 1:
591
+ duration_s = (duration_s * audio_chunks)
592
+
593
+ duration_s = duration_s + warmup_s + claymation
594
+
595
+ print(f'for {audio_chunks} times and {num_steps} steps, {session_id} is preparing for {duration_s}')
596
+
597
+ return int(duration_s)
598
+
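+ # Worked example (illustrative): a new reference image (claymation = 2 * flux_inference = 20 s),
+ # 2 audio chunks and num_steps = 8 give duration_s = (4 * 7 + 20) * 2 + 15 + 20 = 131 s,
+ # which is the ZeroGPU budget requested by the @spaces.GPU(duration=get_duration) call below.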
599
+ def preprocess_img(input_image_path, raw_image_path, session_id = None):
600
+
601
+ if session_id is None:
602
+ session_id = uuid.uuid4().hex
603
+
604
+ if input_image_path is None:
605
+ return None, None
606
+
607
+ if raw_image_path == '':
608
+ raw_image_path = input_image_path
609
+
610
+ image = Image.open(raw_image_path).convert("RGB")
611
+
612
+ img_id = uuid.uuid4().hex
613
+
614
+ image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
615
+
616
+ _, _, h, w = image.shape
617
+ select_size = match_size(getattr(args, f'image_sizes_{args.max_hw}'), h, w)
618
+ image = resize_pad(image, (h, w), select_size)
619
+ image = image * 2.0 - 1.0
620
+ image = image[:, :, None]
621
+
622
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
623
+
624
+ img_dir = output_dir + '/image'
625
+ os.makedirs(img_dir, exist_ok=True)
626
+ input_img_path = os.path.join(img_dir, f"img_{img_id}.jpg")
627
+
628
+ image = tensor_to_pil(image)
629
+ image.save(input_img_path)
630
+
631
+ return input_img_path, raw_image_path
632
+
633
+ def infer_example(image_path, audio_path, num_steps, raw_image_path, session_id = None, progress=gr.Progress(track_tqdm=True),):
634
+
635
+ current_image_size = args.image_sizes_720
636
+ args.image_sizes_720 = [[720, 400]]
637
+ text = ADAPTIVE_PROMPT_TEMPLATES[2]
638
+
639
+ result = infer(image_path, audio_path, text, num_steps, session_id, progress)
640
+
641
+ args.image_sizes_720 = current_image_size
642
+
643
+ return result
644
+
645
+ @spaces.GPU(duration=get_duration)
646
+ def infer(image_path, audio_path, text, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True),):
647
+
648
+ if image_path is None:
649
+
650
+ return None
651
+
652
+ if audio_path is None:
653
+
654
+ return None
655
+
656
+ if session_id is None:
657
+ session_id = uuid.uuid4().hex
658
+
659
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
660
+
661
+ # Decompose the path
662
+ dirpath = os.path.dirname(image_path)
663
+ basename = os.path.basename(image_path) # e.g. "photo.png"
664
+ name, ext = os.path.splitext(basename) # name="photo", ext=".png"
665
+
666
+ # Rebuild with "clay_" prefix
667
+ new_basename = f"clay_{name}{ext}" # "clay_photo.png"
668
+ clay_image_path = os.path.join(dirpath, new_basename)
669
+
670
+ # If the output file already exists, skip inference
671
+ if os.path.exists(clay_image_path):
672
+
673
+ print("using existing image")
674
+
675
+ else:
676
+
677
+ flux_prompt = "in style of omniavatar-claymation"
678
+ raw_image = load_image(image_path)
679
+ w, h = raw_image.size
680
+
681
+ clay_image = flux_pipe(image=raw_image, width=w, height=h, prompt=flux_prompt, negative_prompt=args.negative_prompt, num_inference_steps=flux_inference, true_cfg_scale=2.5).images[0]
682
+ clay_image.save(clay_image_path)
683
+
684
+
685
+ audio_dir = output_dir + '/audio'
686
+ os.makedirs(audio_dir, exist_ok=True)
687
+ if args.silence_duration_s > 0:
688
+ input_audio_path = os.path.join(audio_dir, f"audio_input.wav")
689
+ else:
690
+ input_audio_path = audio_path
691
+ prompt_dir = output_dir + '/prompt'
692
+ os.makedirs(prompt_dir, exist_ok=True)
693
+
694
+ if args.silence_duration_s > 0:
695
+ add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
696
+
697
+ tmp2_audio_path = os.path.join(audio_dir, f"audio_out.wav")
698
+ prompt_path = os.path.join(prompt_dir, f"prompt.txt")
699
+
700
+ video = inferpipe(
701
+ prompt=text,
702
+ image_path=clay_image_path,
703
+ audio_path=input_audio_path,
704
+ seq_len=args.seq_len,
705
+ num_steps=num_steps
706
+ )
707
+
708
+ torch.cuda.empty_cache()
709
+
710
+ add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
711
+ video_paths = save_video_as_grid_and_mp4(video,
712
+ output_dir,
713
+ args.fps,
714
+ prompt=text,
715
+ prompt_path = prompt_path,
716
+ audio_path=tmp2_audio_path if args.use_audio else None,
717
+ prefix=f'result')
718
+
719
+ return video_paths[0]
720
+
721
+ def apply_image(request):
722
+ print('image applied')
723
+ return request, None
724
+
725
+ def apply_audio(request):
726
+ print('audio applied')
727
+ return request
728
+
729
+ def cleanup(request: gr.Request):
730
+
731
+ sid = request.session_hash
732
+ if sid:
733
+ d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
734
+ shutil.rmtree(d1, ignore_errors=True)
735
+
736
+ def start_session(request: gr.Request):
737
+
738
+ return request.session_hash
739
+
740
+ def orientation_changed(session_id, evt: gr.EventData):
741
+
742
+ detail = getattr(evt, "data", None) or getattr(evt, "_data", {}) or {}
743
+
744
+ if detail['value'] == "9:16":
745
+ args.image_sizes_720 = [[720, 400]]
746
+ elif detail['value'] == "1:1":
747
+ args.image_sizes_720 = [[720, 720]]
748
+ elif detail['value'] == "16:9":
749
+ args.image_sizes_720 = [[400, 720]]
750
+
751
+ print(f'{session_id} has {args.image_sizes_720} orientation')
752
+
753
+ def clear_raw_image():
754
+ return ''
755
+
756
+ def preprocess_audio_first_nseconds_librosa(audio_path, limit_in_seconds, session_id=None):
757
+
758
+ if not audio_path:
759
+ return None
760
+
761
+ # Robust duration check (librosa changed arg name across versions)
762
+ try:
763
+ dur = librosa.get_duration(path=audio_path)
764
+ except TypeError:
765
+ dur = librosa.get_duration(filename=audio_path)
766
+
767
+ # Small tolerance to avoid re-encoding 4.9999s files
768
+ if dur < 5.0 - 1e-3:
769
+ return audio_path
770
+
771
+ if session_id is None:
772
+ session_id = uuid.uuid4().hex
773
+
774
+ # Where we'll store per-session processed audio
775
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
776
+ audio_dir = os.path.join(output_dir, "audio")
777
+ os.makedirs(audio_dir, exist_ok=True)
778
+
779
+ trimmed_path = os.path.join(audio_dir, f"audio_input_{limit_in_seconds}s.wav")
780
+ sr = getattr(args, "sample_rate", 16000)
781
+
782
+ y, _ = librosa.load(audio_path, sr=sr, mono=True, duration=float(limit_in_seconds))
783
+
784
+ # Save as 16-bit PCM mono WAV
785
+ waveform = torch.from_numpy(y).unsqueeze(0) # [1, num_samples]
786
+ torchaudio.save(
787
+ trimmed_path,
788
+ waveform,
789
+ sr,
790
+ encoding="PCM_S",
791
+ bits_per_sample=16,
792
+ format="wav",
793
+ )
794
+
795
+ return trimmed_path
796
+
797
+
798
+ css = """
799
+ #col-container {
800
+ margin: 0 auto;
801
+ max-width: 1560px;
802
+ }
803
+
804
+ /* editable vs locked, reusing theme variables that adapt to dark/light */
805
+ .stateful textarea:not(:disabled):not([readonly]) {
806
+ color: var(--color-text) !important; /* accent in both modes */
807
+ }
808
+ .stateful textarea:disabled,
809
+ .stateful textarea[readonly]{
810
+ color: var(--body-text-color-subdued) !important; /* subdued in both modes */
811
+ }
812
+ """
813
+
814
+ with gr.Blocks(css=css) as demo:
815
+
816
+ session_state = gr.State()
817
+ demo.load(start_session, outputs=[session_state])
818
+
819
+
820
+ with gr.Column(elem_id="col-container"):
821
+ gr.HTML(
822
+ """
823
+ <div style="text-align: center;">
824
+ <div style="display: flex; justify-content: center;">
825
+ <img src="https://huggingface.co/spaces/alexnasa/OmniAvatar-Clay-Fast/resolve/main/assets/logo-omniavatar.png" alt="Logo">
826
+ </div>
827
+ </div>
828
+ <div style="text-align: center;">
829
+ <p style="font-size:16px; display: inline; margin: 0;">
830
+ <strong>OmniAvatar</strong> – Efficient Audio-Driven Avatar Video Generation with Adaptive Body Animation
831
+ </p>
832
+ <a href="https://huggingface.co/OmniAvatar/OmniAvatar-1.3B" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
833
+ [model]
834
+ </a>
835
+ </div>
836
+
837
+ <div style="text-align: center;">
838
+ <strong>HF Space by:</strong>
839
+ <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
840
+ <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
841
+ </a>
842
+ </div>
843
+
844
+ <div style="text-align: center;">
845
+ <p style="font-size:16px; display: inline; margin: 0;">
846
+ If you are looking for realism, please try the other HF Space:
847
+ </p>
848
+ <a href="https://huggingface.co/spaces/alexnasa/OmniAvatar" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
849
+ <img src="https://img.shields.io/badge/🤗-HF Demo-yellow.svg">
850
+ </a>
851
+ </div>
852
+
853
+ """
854
+ )
855
+
856
+ with gr.Row():
857
+
858
+ with gr.Column(scale=1):
859
+
860
+ image_input = extendedimage(label="Reference Image", type="filepath", height=512)
861
+ audio_input = gr.Audio(label="Input Audio", type="filepath")
862
+ gr.Markdown("*Change the duration limit in Advanced Settings*")
863
+
864
+
865
+ with gr.Column(scale=1):
866
+
867
+ output_video = gr.Video(label="Avatar", height=512)
868
+ num_steps = gr.Slider(8, 50, value=8, step=1, label="Steps")
869
+ time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
870
+ infer_btn = gr.Button("🗿 Clay Me", variant="primary")
871
+ with gr.Accordion("Advanced Settings", open=False):
872
+ raw_img_text = gr.Text(show_label=False, label="", value='', visible=False)
873
+ limit_in_seconds = gr.Slider(5, 180, value=5, step=5, label="Duration")
874
+ text_input = gr.Textbox(label="Prompt", lines=6, value= ADAPTIVE_PROMPT_TEMPLATES[2])
875
+
876
+ with gr.Column(scale=1):
877
+
878
+ cached_examples = gr.Examples(
879
+ examples=[
880
+
881
+ [
882
+ "examples/images/female-003.png",
883
+ "examples/audios/fox.wav",
884
+ 8,
885
+ ''
886
+ ],
887
+
888
+
889
+ [
890
+ "examples/images/male-001.png",
891
+ "examples/audios/ocean.wav",
892
+ 8,
893
+ ''
894
+ ],
895
+
896
+ [
897
+ "examples/images/female-002.png",
898
+ "examples/audios/lion.wav",
899
+ 16,
900
+ ''
901
+ ],
902
+
903
+
904
+ [
905
+ "examples/images/female-009.png",
906
+ "examples/audios/script.wav",
907
+ 8,
908
+ ''
909
+ ],
910
+
911
+ ],
912
+ label="Cached Examples",
913
+ inputs=[image_input, audio_input, num_steps, raw_img_text],
914
+ outputs=[output_video],
915
+ fn=infer_example,
916
+ cache_examples=True
917
+ )
918
+
919
+
920
+ infer_btn.click(
921
+ fn=infer,
922
+ inputs=[image_input, audio_input, text_input, num_steps, session_state],
923
+ outputs=[output_video]
924
+ )
925
+
926
+ image_input.orientation(fn=orientation_changed, inputs=[session_state]).then(fn=preprocess_img, inputs=[image_input, raw_img_text, session_state], outputs=[image_input, raw_img_text])
927
+ image_input.clear(fn=clear_raw_image, outputs=[raw_img_text])
928
+ image_input.upload(fn=preprocess_img, inputs=[image_input, raw_img_text, session_state], outputs=[image_input, raw_img_text])
929
+ image_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required])
930
+ audio_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required])
931
+ num_steps.change(fn=slider_value_change, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required, text_input])
932
+ audio_input.upload(fn=apply_audio, inputs=[audio_input], outputs=[audio_input]
933
+ ).then(
934
+ fn=preprocess_audio_first_nseconds_librosa,
935
+ inputs=[audio_input, limit_in_seconds, session_state],
936
+ outputs=[audio_input],
937
+ )
938
+
939
+ if __name__ == "__main__":
940
+ demo.unload(cleanup)
941
+ demo.queue()
942
+ demo.launch(ssr_mode=False)
args_config.yaml ADDED
@@ -0,0 +1,71 @@
1
+ config: configs/inference.yaml
2
+
3
+ input_file: examples/infer_samples.txt
4
+ debug: null
5
+ infer: false
6
+ hparams: ''
7
+ dtype: bf16
8
+
9
+ exp_path: pretrained_models/OmniAvatar-1.3B
10
+ text_encoder_path: pretrained_models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
11
+ image_encoder_path: None
12
+ dit_path: pretrained_models/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors
13
+ vae_path: pretrained_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
14
+
15
+ wav2vec_path: pretrained_models/wav2vec2-base-960h
16
+ num_persistent_param_in_dit:
17
+ reload_cfg: true
18
+ sp_size: 1
19
+ seed: 42
20
+ image_sizes_720:
21
+ # - - 400
22
+ # - 720
23
+ # - - 720 commented out due to the duration needed on HF
24
+ # - 720
25
+ - - 720
26
+ - 400
27
+ image_sizes_1280:
28
+ - - 720
29
+ - 720
30
+ - - 528
31
+ - 960
32
+ - - 960
33
+ - 528
34
+ - - 720
35
+ - 1280
36
+ - - 1280
37
+ - 720
38
+ max_hw: 720
39
+ max_tokens: 40000
40
+ seq_len: 200
41
+ overlap_frame: 13
42
+ guidance_scale: 4.5
43
+ audio_scale: null
44
+ num_steps: 8
45
+ fps: 24
46
+ sample_rate: 16000
47
+ negative_prompt: Vivid color tones, background/camera moving quickly, screen switching,
48
+ subtitles and special effects, mutation, overexposed, static, blurred details, subtitles,
49
+ style, work, painting, image, still, overall grayish, worst quality, low quality,
50
+ JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly
51
+ drawn face, deformed, disfigured, malformed limbs, fingers merging, motionless image,
52
+ chaotic background, three legs, crowded background with many people, walking backward
53
+ silence_duration_s: 0.0
54
+ use_fsdp: false
55
+ tea_cache_l1_thresh: 0
56
+ rank: 0
57
+ world_size: 1
58
+ local_rank: 0
59
+ device: cuda
60
+ num_nodes: 1
61
+ i2v: true
62
+ use_audio: true
63
+ random_prefix_frames: true
64
+ model_config:
65
+ in_dim: 33
66
+ audio_hidden_size: 32
67
+ train_architecture: lora
68
+ lora_target_modules: q,k,v,o,ffn.0,ffn.2
69
+ init_lora_weights: kaiming
70
+ lora_rank: 128
71
+ lora_alpha: 64.0
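A minimal sketch of how this file is consumed (it mirrors the OmegaConf/Namespace loading at the top of app.py shown above; set_global_args is called before the rest of the OmniAvatar modules are imported):

from argparse import Namespace
from omegaconf import OmegaConf
from OmniAvatar.utils.args_config import set_global_args

cfg = OmegaConf.load("args_config.yaml")
args = Namespace(**OmegaConf.to_container(cfg, resolve=True))
set_global_args(args)
print(args.max_tokens, args.image_sizes_720)   # 40000 [[720, 400]]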
assets/logo-omniavatar.png ADDED
assets/material/pipeline.png ADDED
assets/material/teaser.png ADDED
configs/inference.yaml ADDED
@@ -0,0 +1,37 @@
1
+ # Pretrained model paths
2
+ dtype: "bf16"
3
+ text_encoder_path: pretrained_models/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth
4
+ image_encoder_path: None
5
+ dit_path: pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors
6
+ vae_path: pretrained_models/Wan2.1-T2V-14B/Wan2.1_VAE.pth
7
+ wav2vec_path: pretrained_models/wav2vec2-base-960h
8
+ exp_path: pretrained_models/OmniAvatar-14B
9
+ num_persistent_param_in_dit: # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
10
+
11
+ reload_cfg: True
12
+ sp_size: 1
13
+
14
+ # Data parameters
15
+ seed: 42
16
+ image_sizes_720: [[400, 720],
17
+ [720, 720],
18
+ [720, 400]]
19
+ image_sizes_1280: [
20
+ [720, 720],
21
+ [528, 960],
22
+ [960, 528],
23
+ [720, 1280],
24
+ [1280, 720]]
25
+ max_hw: 720 # 720: 480p; 1280: 720p
26
+ max_tokens: 30000
27
+ seq_len: 200
28
+ overlap_frame: 13 # must be 1 + 4*n
29
+ guidance_scale: 4.5
30
+ audio_scale:
31
+ num_steps: 16
32
+ fps: 25
33
+ sample_rate: 16000
34
+ negative_prompt: "Vivid color tones, background/camera moving quickly, screen switching, subtitles and special effects, mutation, overexposed, static, blurred details, subtitles, style, work, painting, image, still, overall grayish, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fingers merging, motionless image, chaotic background, three legs, crowded background with many people, walking backward"
35
+ silence_duration_s: 0.3
36
+ use_fsdp: False
37
+ tea_cache_l1_thresh: 0 # e.g. 0.14; the larger this value, the faster the inference but the worse the visual quality. TODO: verify this value
configs/inference_1.3B.yaml ADDED
@@ -0,0 +1,37 @@
1
+ # Pretrained model paths
2
+ dtype: "bf16"
3
+ text_encoder_path: pretrained_models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
4
+ image_encoder_path: None
5
+ dit_path: pretrained_models/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors
6
+ vae_path: pretrained_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
7
+ wav2vec_path: pretrained_models/wav2vec2-base-960h
8
+ exp_path: pretrained_models/OmniAvatar-1.3B
9
+ num_persistent_param_in_dit: # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
10
+
11
+ reload_cfg: True
12
+ sp_size: 1
13
+
14
+ # Data parameters
15
+ seed: 42
16
+ image_sizes_720: [[400, 720],
17
+ [720, 720],
18
+ [720, 400]]
19
+ image_sizes_1280: [
20
+ [720, 720],
21
+ [528, 960],
22
+ [960, 528],
23
+ [720, 1280],
24
+ [1280, 720]]
25
+ max_hw: 720 # 720: 480p; 1280: 720p
26
+ max_tokens: 30000
27
+ seq_len: 200
28
+ overlap_frame: 13 # must be 1 + 4*n
29
+ guidance_scale: 4.5
30
+ audio_scale:
31
+ num_steps: 10
32
+ fps: 25
33
+ sample_rate: 16000
34
+ negative_prompt: "Vivid color tones, background/camera moving quickly, screen switching, subtitles and special effects, mutation, overexposed, static, blurred details, subtitles, style, work, painting, image, still, overall grayish, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fingers merging, motionless image, chaotic background, three legs, crowded background with many people, walking backward"
35
+ silence_duration_s: 0.3
36
+ use_fsdp: False
37
+ tea_cache_l1_thresh: 0 # e.g. 0.14; the larger this value, the faster the inference but the worse the visual quality. TODO: verify this value
examples/audios/fox.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4de974b2b0b46ae66a545c04ed98d54e18dc9d67dd8cd8d50aad91dfa978624e
3
+ size 2060268
examples/audios/lion.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76d6b7292da45406ee5b6c7e10dbedcbbb6647a5b0872a3f506419c014696e72
3
+ size 1633964
examples/audios/ocean.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7cdb50fd2bc117cbe8e4b37bb3ac7d257511ecba90cb26641a75fe569390c41f
3
+ size 1749164
examples/audios/script.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:deb654269051d935d85b6e573805a88466e2f7f791f44b77da16467c5207eeec
3
+ size 259244
examples/images/female-002.png ADDED
examples/images/female-003.png ADDED

Git LFS Details

  • SHA256: 1626385413c07a8ff4931c897c78f3861addc99fb66e12f4150ddfe39d0efca6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.25 MB
examples/images/female-009.png ADDED
examples/images/male-001.png ADDED

Git LFS Details

  • SHA256: f8b88789fe691d92d843327cb22d61ca6628d147bfef7e3e8a020de876db017b
  • Pointer size: 132 Bytes
  • Size of remote file: 2.43 MB
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ pytest
2
+ diffusers
3
+ torchao
4
+ tqdm
5
+ librosa==0.10.2.post1
6
+ peft>=0.17.0
7
+ transformers==4.52.3
8
+ scipy==1.14.0
9
+ numpy==1.26.4
10
+ ftfy
11
+ einops
12
+ omegaconf
13
+ torchvision
14
+ ninja
15
+ imageio[ffmpeg]
16
+ sentencepiece
17
+ torchaudio
18
+ gradio_extendedimage @ https://github.com/OutofAi/gradio-extendedimage/releases/download/0.0.2/gradio_extendedimage-0.0.2-py3-none-any.whl
scripts/inference.py ADDED
@@ -0,0 +1,383 @@
1
+ import subprocess
2
+ import os, sys
3
+ from glob import glob
4
+ from datetime import datetime
5
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
6
+ import math
7
+ import random
8
+ import librosa
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ from tqdm import tqdm
13
+ from functools import partial
14
+ from omegaconf import OmegaConf
15
+ from argparse import Namespace
16
+
17
+ # # load the one true config you dumped
18
+ # _args_cfg = OmegaConf.load("demo_out/config/args_config.yaml")
19
+ # args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
20
+
21
+ # from OmniAvatar.utils.args_config import set_global_args
22
+
23
+ # set_global_args(args)
24
+
25
+ from OmniAvatar.utils.args_config import parse_args
26
+ args = parse_args()
27
+
28
+ from OmniAvatar.utils.io_utils import load_state_dict
29
+ from peft import LoraConfig, inject_adapter_in_model
30
+ from OmniAvatar.models.model_manager import ModelManager
31
+ from OmniAvatar.wan_video import WanVideoPipeline
32
+ from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
33
+ import torchvision.transforms as TT
34
+ from transformers import Wav2Vec2FeatureExtractor
35
+ import torchvision.transforms as transforms
36
+ import torch.nn.functional as F
37
+ from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
38
+ from huggingface_hub import hf_hub_download
39
+
40
+ def set_seed(seed: int = 42):
41
+ random.seed(seed)
42
+ np.random.seed(seed)
43
+ torch.manual_seed(seed)
44
+ torch.cuda.manual_seed(seed) # seed the current GPU
45
+ torch.cuda.manual_seed_all(seed) # seed all GPUs
46
+
47
+ def read_from_file(p):
48
+ with open(p, "r") as fin:
49
+ for l in fin:
50
+ yield l.strip()
51
+
52
+ def match_size(image_size, h, w):
53
+ ratio_ = 9999
54
+ size_ = 9999
55
+ select_size = None
56
+ for image_s in image_size:
57
+ ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
58
+ size_tmp = abs(max(image_s) - max(w, h))
59
+ if ratio_tmp < ratio_:
60
+ ratio_ = ratio_tmp
61
+ size_ = size_tmp
62
+ select_size = image_s
63
+ if ratio_ == ratio_tmp:
64
+ if size_ == size_tmp:
65
+ select_size = image_s
66
+ return select_size
67
+
68
+ def resize_pad(image, ori_size, tgt_size):
69
+ h, w = ori_size
70
+ scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
71
+ scale_h = int(h * scale_ratio)
72
+ scale_w = int(w * scale_ratio)
73
+
74
+ image = transforms.Resize(size=[scale_h, scale_w])(image)
75
+
76
+ padding_h = tgt_size[0] - scale_h
77
+ padding_w = tgt_size[1] - scale_w
78
+ pad_top = padding_h // 2
79
+ pad_bottom = padding_h - pad_top
80
+ pad_left = padding_w // 2
81
+ pad_right = padding_w - pad_left
82
+
83
+ image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
84
+ return image
85
+
86
+ class WanInferencePipeline(nn.Module):
87
+ def __init__(self, args):
88
+ super().__init__()
89
+ self.args = args
90
+ self.device = torch.device(f"cuda")
91
+ if self.args.dtype=='bf16':
92
+ self.dtype = torch.bfloat16
93
+ elif self.args.dtype=='fp16':
94
+ self.dtype = torch.float16
95
+ else:
96
+ self.dtype = torch.float32
97
+ self.pipe = self.load_model()
98
+ if self.args.i2v:
99
+ chained_transforms = []
100
+ chained_transforms.append(TT.ToTensor())
101
+ self.transform = TT.Compose(chained_transforms)
102
+ if self.args.use_audio:
103
+ from OmniAvatar.models.wav2vec import Wav2VecModel
104
+ self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
105
+ self.args.wav2vec_path
106
+ )
107
+ self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device)
108
+ self.audio_encoder.feature_extractor._freeze_parameters()
109
+
110
+ def load_model(self):
111
+ torch.cuda.set_device(0)
112
+ ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
113
+ assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
114
+ if self.args.train_architecture == 'lora':
115
+ self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
116
+ else:
117
+ resume_path = ckpt_path
118
+
119
+ self.step = 0
120
+
121
+ # Load models
122
+ model_manager = ModelManager(device="cpu", infer=True)
123
+ model_manager.load_models(
124
+ [
125
+ self.args.dit_path.split(","),
126
+ self.args.text_encoder_path,
127
+ self.args.vae_path
128
+ ],
129
+ torch_dtype=self.dtype, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
130
+ device='cpu',
131
+ )
132
+ LORA_REPO_ID = "Kijai/WanVideo_comfy"
133
+ LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
134
+ causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
135
+ model_manager.load_lora(causvid_path, lora_alpha=1.0)
136
+ pipe = WanVideoPipeline.from_model_manager(model_manager,
137
+ torch_dtype=self.dtype,
138
+ device=f"cuda",
139
+ use_usp=True if self.args.sp_size > 1 else False,
140
+ infer=True)
141
+ if self.args.train_architecture == "lora":
142
+ print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
143
+ self.add_lora_to_model(
144
+ pipe.denoising_model(),
145
+ lora_rank=self.args.lora_rank,
146
+ lora_alpha=self.args.lora_alpha,
147
+ lora_target_modules=self.args.lora_target_modules,
148
+ init_lora_weights=self.args.init_lora_weights,
149
+ pretrained_lora_path=pretrained_lora_path,
150
+ )
151
+ else:
152
+ missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
153
+ print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
154
+ pipe.requires_grad_(False)
155
+ pipe.eval()
156
+ pipe.enable_vram_management(num_persistent_param_in_dit=self.args.num_persistent_param_in_dit) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
157
+ return pipe
158
+
159
+ def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
160
+ # Add LoRA to the denoising model (DiT)
161
+ self.lora_alpha = lora_alpha
162
+ if init_lora_weights == "kaiming":
163
+ init_lora_weights = True
164
+
165
+ lora_config = LoraConfig(
166
+ r=lora_rank,
167
+ lora_alpha=lora_alpha,
168
+ init_lora_weights=init_lora_weights,
169
+ target_modules=lora_target_modules.split(","),
170
+ )
171
+ model = inject_adapter_in_model(lora_config, model)
172
+
173
+ # Load pretrained LoRA weights
174
+ if pretrained_lora_path is not None:
175
+ state_dict = load_state_dict(pretrained_lora_path)
176
+ if state_dict_converter is not None:
177
+ state_dict = state_dict_converter(state_dict)
178
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
179
+ all_keys = [i for i, _ in model.named_parameters()]
180
+ num_updated_keys = len(all_keys) - len(missing_keys)
181
+ num_unexpected_keys = len(unexpected_keys)
182
+ print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
183
+
184
+
185
+ def forward(self, prompt,
186
+ image_path=None,
187
+ audio_path=None,
188
+ seq_len=101, # not used when audio_path is provided
189
+ height=720,
190
+ width=720,
191
+ overlap_frame=None,
192
+ num_steps=None,
193
+ negative_prompt=None,
194
+ guidance_scale=None,
195
+ audio_scale=None):
196
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
197
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
198
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
199
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
200
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
201
+
202
+ if image_path is not None:
203
+ from PIL import Image
204
+ image = Image.open(image_path).convert("RGB")
205
+ image = self.transform(image).unsqueeze(0).to(self.device)
206
+ _, _, h, w = image.shape
207
+ select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
208
+ image = resize_pad(image, (h, w), select_size)
209
+ image = image * 2.0 - 1.0
210
+ image = image[:, :, None]
211
+ else:
212
+ image = None
213
+ select_size = [height, width]
214
+ L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
215
+ L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
216
+ T = (L + 3) // 4 # latent frames
217
+
218
+ if self.args.i2v:
219
+ if self.args.random_prefix_frames:
220
+ fixed_frame = overlap_frame
221
+ assert fixed_frame % 4 == 1
222
+ else:
223
+ fixed_frame = 1
224
+ prefix_lat_frame = (3 + fixed_frame) // 4
225
+ first_fixed_frame = 1
226
+ else:
227
+ fixed_frame = 0
228
+ prefix_lat_frame = 0
229
+ first_fixed_frame = 0
230
+
231
+
232
+ if audio_path is not None and self.args.use_audio:
233
+ audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
234
+ input_values = np.squeeze(
235
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
236
+ )
237
+ input_values = torch.from_numpy(input_values).float().to(device=self.device)
238
+ ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
239
+ input_values = input_values.unsqueeze(0)
240
+ # padding audio
241
+ if audio_len < L - first_fixed_frame:
242
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
243
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
244
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
245
+ input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
246
+ with torch.no_grad():
247
+ hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
248
+ audio_embeddings = hidden_states.last_hidden_state
249
+ for mid_hidden_states in hidden_states.hidden_states:
250
+ audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
251
+ seq_len = audio_len
252
+ audio_embeddings = audio_embeddings.squeeze(0)
253
+ audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
254
+ else:
255
+ audio_embeddings = None
256
+
257
+         # sliding-window generation loop: how many windows are needed to cover the sequence
+         times = (seq_len - L + first_fixed_frame) // (L - fixed_frame) + 1
+         if times * (L - fixed_frame) + fixed_frame < seq_len:
+             times += 1
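+         # Purely illustrative numbers: L=57, fixed_frame=13, first_fixed_frame=1 and seq_len=188
+         # give (188 - 57 + 1) // 44 + 1 = 4 windows; the check above adds one more window only
+         # when that count would still fall short of seq_len.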
+         video = []
+         image_emb = {}
+         img_lat = None
+         if self.args.i2v:
+             self.pipe.load_models_to_device(['vae'])
+             img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device)
+
+             msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:, :1])
+             image_cat = img_lat.repeat(1, 1, T, 1, 1)
+             msk[:, :, 1:] = 1
+             image_emb["y"] = torch.cat([image_cat, msk], dim=1)
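+             # image_emb["y"] packs the reference latent repeated across all T latent frames plus
+             # a one-channel mask (0 marks conditioning frames to keep, 1 marks frames to be
+             # generated); judging by how it is consumed downstream, this follows the Wan-style
+             # i2v conditioning layout.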
+         for t in range(times):
+             print(f"[{t+1}/{times}]")
+             audio_emb = {}
+             if t == 0:
+                 overlap = first_fixed_frame
+             else:
+                 overlap = fixed_frame
+                 image_emb["y"][:, -1:, :prefix_lat_frame] = 0  # first window: only 1 frame (the reference) is masked as fixed; later windows fix the whole overlap
+             prefix_overlap = (3 + overlap) // 4
+             if audio_embeddings is not None:
+                 if t == 0:
+                     audio_tensor = audio_embeddings[
+                         :min(L - overlap, audio_embeddings.shape[0])
+                     ]
+                 else:
+                     audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
+                     audio_tensor = audio_embeddings[
+                         audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
+                     ]
+
+                 audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
+                 audio_prefix = audio_tensor[-fixed_frame:]
+                 audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
+                 audio_emb["audio_emb"] = audio_tensor
+             else:
+                 audio_prefix = None
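+             # The last `fixed_frame` audio embeddings are carried over as `audio_prefix`, so the
+             # overlapped frames of the next window are driven by the same audio context they
+             # were generated with.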
+             if image is not None and img_lat is None:
+                 self.pipe.load_models_to_device(['vae'])
+                 img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device)
+                 assert img_lat.shape[2] == prefix_overlap
+             img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1))], dim=2)
+             frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
+                                                      negative_prompt, num_inference_steps=num_steps,
+                                                      cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
+                                                      return_latent=True,
+                                                      tea_cache_l1_thresh=self.args.tea_cache_l1_thresh, tea_cache_model_id="Wan2.1-T2V-14B")
+             img_lat = None
+             image = (frames[:, -fixed_frame:].clip(0, 1) * 2 - 1).permute(0, 2, 1, 3, 4).contiguous()  # keep the tail frames as the next window's prefix
+             if t == 0:
+                 video.append(frames)
+             else:
+                 video.append(frames[:, overlap:])  # drop frames already emitted by the previous window
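+         # Each appended chunk already excludes its overlapped prefix, so a plain concat along
+         # the frame axis yields a continuous clip; it is then trimmed to the original audio
+         # length plus one extra frame for the reference image.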
+         video = torch.cat(video, dim=1)
+         video = video[:, :ori_audio_len + 1]
+         return video
+
+
+ def main():
+
+     # os.makedirs("demo_out/config", exist_ok=True)
+     # OmegaConf.save(config=OmegaConf.create(vars(args)),
+     #                f="demo_out/config/args_config.yaml")
+     # print("Saved merged args to demo_out/config/args_config.yaml")
+
+     set_seed(args.seed)
+     # load data
+     data_iter = read_from_file(args.input_file)
+     exp_name = os.path.basename(args.exp_path)
+     seq_len = args.seq_len
+
+     # build the audio-driven avatar inference pipeline
+     inferpipe = WanInferencePipeline(args)
+
+     output_dir = 'demo_out'
+
+     idx = 0
+     text = "A realistic video of a man speaking directly to the camera on a sofa, with dynamic and rhythmic hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence."
+     image_path = "examples/images/0000.jpeg"
+     audio_path = "examples/audios/0000.MP3"
+     audio_dir = output_dir + '/audio'
+     os.makedirs(audio_dir, exist_ok=True)
+     if args.silence_duration_s > 0:
+         input_audio_path = os.path.join(audio_dir, f"audio_input_{idx:03d}.wav")
+     else:
+         input_audio_path = audio_path
+     prompt_dir = output_dir + '/prompt'
+     os.makedirs(prompt_dir, exist_ok=True)
+
+     if args.silence_duration_s > 0:
+         add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
+
+     video = inferpipe(
+         prompt=text,
+         image_path=image_path,
+         audio_path=input_audio_path,
+         seq_len=seq_len
+     )
+     tmp2_audio_path = os.path.join(audio_dir, f"audio_out_{idx:03d}.wav")  # the first frame is the reference frame, so the audio is shifted by one frame (1/fps s) of prepended silence
+     prompt_path = os.path.join(prompt_dir, f"prompt_{idx:03d}.txt")
+
+
+     add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
+     save_video_as_grid_and_mp4(video,
+                                output_dir,
+                                args.fps,
+                                prompt=text,
+                                prompt_path=prompt_path,
+                                audio_path=tmp2_audio_path if args.use_audio else None,
+                                prefix=f'result_{idx:03d}')
+
+
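+ # Outputs land under demo_out/: the shifted audio in demo_out/audio/, the prompt text in
+ # demo_out/prompt/, and the rendered video saved with the prefix result_000 (the exact file
+ # naming is up to save_video_as_grid_and_mp4).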
+ class NoPrint:
+     def write(self, x):
+         pass
+     def flush(self):
+         pass
+
+ if __name__ == '__main__':
+     if not args.debug:
+         if args.local_rank != 0:  # silence stdout on every rank except rank 0
+             sys.stdout = NoPrint()
+     main()