File size: 5,937 Bytes
1d212f5
 
 
 
 
 
 
805097b
 
 
1d212f5
 
 
 
 
 
 
a8689d4
d83fb5a
8750a83
d83fb5a
1d212f5
d83fb5a
1d212f5
 
 
 
 
 
d83fb5a
 
 
 
 
 
 
 
 
dc155d4
58fd9c6
1d212f5
dc155d4
 
 
1d212f5
a8689d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a376574
 
74e763a
a376574
dc03ac9
 
a376574
 
 
 
250aeeb
ab8f891
c102828
74e763a
27a3ad2
250aeeb
a376574
 
 
 
 
 
a8689d4
1d212f5
 
 
8750a83
 
dc155d4
8750a83
 
1d212f5
8750a83
 
 
 
 
 
 
 
 
 
 
 
1d212f5
8750a83
3faf8ae
1d212f5
8750a83
 
 
1d212f5
8750a83
3faf8ae
8750a83
 
 
58fd9c6
8750a83
 
dc155d4
1d212f5
 
 
 
 
 
dc155d4
1d212f5
8750a83
6e8eb03
8750a83
 
 
 
 
 
dc155d4
8750a83
 
 
 
 
 
58fd9c6
a8689d4
 
6e8eb03
a8689d4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
"""

from typing import Any
from typing import Callable
from typing import ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig

from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import ZeroGPUCompiledModel
from optimization_utils import drain_module_parameters

# Parameter specification used by `optimize_pipeline_` so that its *args/**kwargs
# annotations are tied to the signature of the `pipeline` callable.
P = ParamSpec('P')


# Symbolic export dimension named 'num_frames', declared with min=3 and max=21.
TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)

# Dynamic-shape spec merged into the transformer export call: dimension index 2
# of the `hidden_states` input is bound to TRANSFORMER_NUM_FRAMES_DIM.
# NOTE(review): index 2 is presumably the frame axis (going by the Dim's name) —
# confirm against the transformer's input layout.
TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        2: TRANSFORMER_NUM_FRAMES_DIM,
    },
}

# Options passed as the second argument to `aoti_compile` for every exported
# transformer program.
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}


def _is_landscape(hidden_states: torch.Tensor) -> bool:
    """Return True when the last dimension is strictly larger than the second-to-last."""
    return hidden_states.shape[-1] > hidden_states.shape[-2]


def _dispatch_by_orientation(
    landscape_model: Callable[..., Any],
    portrait_model: Callable[..., Any],
) -> Callable[..., Any]:
    """Build a forward replacement that routes each call on `hidden_states` orientation.

    The returned callable reads `hidden_states` from its keyword arguments (a
    `KeyError` is raised if it is not passed by keyword) and forwards the call
    unchanged to `landscape_model` when `_is_landscape` holds, otherwise to
    `portrait_model`.
    """

    def forward(*args, **kwargs):
        hidden_states: torch.Tensor = kwargs['hidden_states']
        if _is_landscape(hidden_states):
            return landscape_model(*args, **kwargs)
        return portrait_model(*args, **kwargs)

    return forward


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
    """Fuse LoRAs into, quantize, export and compile the pipeline's transformers in place.

    Steps, in order:

    1. Quantize `pipeline.text_encoder` with `Int8WeightOnlyConfig`.
    2. Inside a `spaces.GPU(duration=1500)`-decorated function:
       load two LoRA adapters, fuse them into `transformer` and
       `transformer_2`, unload the LoRA weights, run the pipeline once with
       `*args, **kwargs` while capturing the `transformer` call, quantize both
       transformers with `Float8DynamicActivationFloat8WeightConfig`, export
       one program per orientation and compile them with `aoti_compile`.
    3. Replace `forward` on both transformers with orientation dispatchers
       over the compiled models and call `drain_module_parameters` on each.

    Args:
        pipeline: Pipeline object. It must be callable and expose
            `text_encoder`, `transformer`, `transformer_2` and the LoRA
            methods used below.
        *args: Positional arguments for the single capture run of `pipeline`.
        **kwargs: Keyword arguments for the single capture run of `pipeline`.

    Returns:
        None. `pipeline` is mutated in place.
    """

    @spaces.GPU(duration=1500)
    def compile_transformer():
        # First adapter goes through the default loading path; it is fused
        # into `transformer` below.
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
            adapter_name="lightning",
        )
        # Second adapter is loaded with load_into_transformer_2=True; it is
        # fused into `transformer_2` below.
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Pusa/Wan21_PusaV1_LoRA_14B_rank512_bf16.safetensors",
            adapter_name="lightning_2",
            load_into_transformer_2=True,
        )
        pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])

        # Bake the adapters into the base weights, then drop the LoRA layers so
        # the modules that get quantized/exported are plain transformers.
        pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3., components=["transformer"])
        pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1., components=["transformer_2"])
        pipeline.unload_lora_weights()

        # Run the pipeline once to record the args/kwargs `transformer` is
        # called with; these serve as example inputs for export.
        with capture_component_call(pipeline, 'transformer') as call:
            pipeline(*args, **kwargs)

        # Every tensor/bool kwarg is static (None) except the entries
        # overridden by TRANSFORMER_DYNAMIC_SHAPES.
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        # Quantization happens after the capture run and before export.
        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
        quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())

        # Derive one example input per orientation from the captured tensor by
        # swapping its last two dimensions.
        hidden_states: torch.Tensor = call.kwargs['hidden_states']
        hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
        if _is_landscape(hidden_states):
            hidden_states_landscape = hidden_states
            hidden_states_portrait = hidden_states_transposed
        else:
            hidden_states_landscape = hidden_states_transposed
            hidden_states_portrait = hidden_states

        # Only two exports are made: landscape from `transformer`, portrait
        # from `transformer_2`.
        exported_landscape_1 = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
            dynamic_shapes=dynamic_shapes,
        )
        exported_portrait_2 = torch.export.export(
            mod=pipeline.transformer_2,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
            dynamic_shapes=dynamic_shapes,
        )

        compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
        compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)

        # The remaining two combinations reuse a compiled archive with the
        # other transformer's weights instead of compiling again.
        # NOTE(review): this assumes both transformers share one architecture
        # so an archive is valid for either weight set — confirm.
        compiled_landscape_2 = ZeroGPUCompiledModel(
            compiled_landscape_1.archive_file,
            compiled_portrait_2.weights,
        )
        compiled_portrait_1 = ZeroGPUCompiledModel(
            compiled_portrait_2.archive_file,
            compiled_landscape_1.weights,
        )

        return (
            compiled_landscape_1,
            compiled_landscape_2,
            compiled_portrait_1,
            compiled_portrait_2,
        )

    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
    landscape_1, landscape_2, portrait_1, portrait_2 = compile_transformer()

    # Swap in the compiled models, then drain the original module parameters.
    # NOTE(review): draining presumably frees the eager weights now held by the
    # compiled models — confirm in optimization_utils.
    pipeline.transformer.forward = _dispatch_by_orientation(landscape_1, portrait_1)
    drain_module_parameters(pipeline.transformer)

    pipeline.transformer_2.forward = _dispatch_by_orientation(landscape_2, portrait_2)
    drain_module_parameters(pipeline.transformer_2)