"""
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig
from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import ZeroGPUCompiledModel
from optimization_utils import drain_module_parameters
P = ParamSpec('P')
TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
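
# Only the frame axis (dim 2 of `hidden_states`) is exported as dynamic, so a single
# compiled artifact serves every video length within the min/max bounds declared above.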
TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        2: TRANSFORMER_NUM_FRAMES_DIM,
    },
}
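
# TorchInductor settings for the ahead-of-time compile: aggressive autotuning plus CUDA
# graphs, with epilogue fusion disabled and coordinate-descent tuning enabled instead.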
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
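    """Fuse LoRAs, quantize, and ahead-of-time compile the pipeline's transformers in place.

    `args`/`kwargs` should form a representative pipeline call; it is traced once on the GPU
    so the captured transformer inputs can be fed to torch.export.
    """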
    @spaces.GPU(duration=1500)
    def compile_transformer():
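        # Earlier LoRA configuration (the Lightning LoRA loaded into both transformers),
        # kept commented out for reference: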
        # pipeline.load_lora_weights(
        #     "Kijai/WanVideo_comfy",
        #     weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
        #     adapter_name="lightning"
        # )
        # kwargs_lora = {}
        # kwargs_lora["load_into_transformer_2"] = True
        # pipeline.load_lora_weights(
        #     "Kijai/WanVideo_comfy",
        #     weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
        #     #weight_name="Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors",
        #     adapter_name="lightning_2", **kwargs_lora
        # )
        # pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])
        # pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3., components=["transformer"])
        # pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1., components=["transformer_2"])
        # pipeline.unload_lora_weights()
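
        # Active configuration: the Lightning CFG step-distill LoRA goes into `transformer`
        # and the Pusa V1 LoRA into `transformer_2`; both are fused into the base weights and
        # then unloaded so the exported graphs carry no LoRA bookkeeping.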
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
            adapter_name="lightning"
        )
        kwargs_lora = {}
        kwargs_lora["load_into_transformer_2"] = True
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Pusa/Wan21_PusaV1_LoRA_14B_rank512_bf16.safetensors",
            #weight_name="Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors",
            adapter_name="lightning_2", **kwargs_lora
        )
        pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])
        pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3., components=["transformer"])
        pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1., components=["transformer_2"])
        pipeline.unload_lora_weights()
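
        # Run the pipeline once while recording the exact args/kwargs the transformer is
        # called with; the captured tensors serve as example inputs for torch.export below.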
        with capture_component_call(pipeline, 'transformer') as call:
            pipeline(*args, **kwargs)
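
        # Dynamic-shape spec: mark every tensor/bool leaf as fully static (None), then
        # override the hidden_states entry so only the num_frames dimension stays dynamic.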
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
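
        # Quantize both transformers to dynamic-activation / FP8-weight before export so the
        # compiled graphs already contain the quantized kernels.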
        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
        quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
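
        # Prepare one landscape and one portrait version of the captured latents: whichever
        # orientation was traced is kept, and its transpose stands in for the other.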
        hidden_states: torch.Tensor = call.kwargs['hidden_states']
        hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            hidden_states_landscape = hidden_states
            hidden_states_portrait = hidden_states_transposed
        else:
            hidden_states_landscape = hidden_states_transposed
            hidden_states_portrait = hidden_states
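
        # Export transformer with the landscape input and transformer_2 with the portrait
        # input; the two compiled artifacts are recombined below to cover all four cases.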
        exported_landscape_1 = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
            dynamic_shapes=dynamic_shapes,
        )
        exported_portrait_2 = torch.export.export(
            mod=pipeline.transformer_2,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
            dynamic_shapes=dynamic_shapes,
        )
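
        # AoT-compile both exports, then build the two missing orientation/transformer pairs
        # by combining each compiled archive with the other transformer's weights (the two
        # transformers share one architecture, so an archive compiled for one also runs the
        # other).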
        compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
        compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)
        compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
        compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)
        return (
            compiled_landscape_1,
            compiled_landscape_2,
            compiled_portrait_1,
            compiled_portrait_2,
        )
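
    # The text encoder is only weight-quantized (int8); the heavy compilation work runs in a
    # ZeroGPU worker via compile_transformer() above (note its @spaces.GPU decorator).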
    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())

    cl1, cl2, cp1, cp2 = compile_transformer()
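
    # Thin dispatchers: pick the landscape- or portrait-compiled graph per call based on the
    # aspect ratio of the incoming latents.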
    def combined_transformer_1(*args, **kwargs):
        hidden_states: torch.Tensor = kwargs['hidden_states']
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            return cl1(*args, **kwargs)
        else:
            return cp1(*args, **kwargs)
    def combined_transformer_2(*args, **kwargs):
        hidden_states: torch.Tensor = kwargs['hidden_states']
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            return cl2(*args, **kwargs)
        else:
            return cp2(*args, **kwargs)
    # Swap the compiled dispatchers in as the transformers' forward methods, then release the
    # original module parameters, which are no longer needed once the compiled forwards own
    # the weights.
    pipeline.transformer.forward = combined_transformer_1
    drain_module_parameters(pipeline.transformer)

    pipeline.transformer_2.forward = combined_transformer_2
    drain_module_parameters(pipeline.transformer_2)
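
# Illustrative usage (hypothetical values; the real pipeline setup lives in the Space's app code):
#
#   optimize_pipeline_(
#       pipeline,
#       prompt='prompt',
#       height=HEIGHT,
#       width=WIDTH,
#       num_frames=MAX_FRAMES,
#   )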