rahul7star committed
Commit 1d212f5 · verified · 1 Parent(s): 9318052

Update optimization.py

Files changed (1):
  optimization.py  +53 -41

optimization.py CHANGED
@@ -1,14 +1,32 @@
-from typing import Any, Callable, ParamSpec
+"""
+"""
+
+from typing import Any
+from typing import Callable
+from typing import ParamSpec
+
 import spaces
 import torch
 from torch.utils._pytree import tree_map_only
-from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
-from optimization_utils import capture_component_call, aoti_compile, ZeroGPUCompiledModel, drain_module_parameters
+from torchao.quantization import quantize_
+from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
+from torchao.quantization import Int8WeightOnlyConfig
+
+from optimization_utils import capture_component_call
+from optimization_utils import aoti_compile
+from optimization_utils import ZeroGPUCompiledModel
+from optimization_utils import drain_module_parameters
 
 P = ParamSpec('P')
 
+
 TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
-TRANSFORMER_DYNAMIC_SHAPES = {'hidden_states': {2: TRANSFORMER_NUM_FRAMES_DIM}}
+
+TRANSFORMER_DYNAMIC_SHAPES = {
+    'hidden_states': {
+        2: TRANSFORMER_NUM_FRAMES_DIM,
+    },
+}
 
 INDUCTOR_CONFIGS = {
     'conv_1x1_as_mm': True,
@@ -19,48 +37,43 @@ INDUCTOR_CONFIGS = {
     'triton.cudagraphs': True,
 }
 
-def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
-    print("[optimize_pipeline_] Starting pipeline optimization")
-
-    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
-    print("[optimize_pipeline_] Text encoder quantized")
+
+def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
 
     @spaces.GPU(duration=1500)
     def compile_transformer():
-        print("[compile_transformer] Loading LoRA weights")
+
         pipeline.load_lora_weights(
-            "DeepBeepMeep/Wan2.2",
-            weight_name="loras_accelerators/Wan2.2-Lightning_T2V-A14B-4steps-lora_HIGH_fp16.safetensors",
+            "Kijai/WanVideo_comfy",
+            weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
            adapter_name="lightning"
         )
+        kwargs_lora = {}
+        kwargs_lora["load_into_transformer_2"] = True
         pipeline.load_lora_weights(
-            "DeepBeepMeep/Wan2.2",
-            weight_name="loras_accelerators/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors",
-            adapter_name="lightning_2",
-            load_into_transformer_2=True
+            "Kijai/WanVideo_comfy",
+            weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
+            #weight_name="Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors",
+            adapter_name="lightning_2", **kwargs_lora
         )
-        pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1.0, 1.0])
+        pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])
 
-        print("[compile_transformer] Fusing LoRA weights")
-        pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3.0, components=["transformer"])
-        pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1.0, components=["transformer_2"])
-        pipeline.unload_lora_weights()
-
-        print("[compile_transformer] Running dummy forward pass to capture component call")
-        with torch.inference_mode():
-            with capture_component_call(pipeline, 'transformer') as call:
-                pipeline(*args, **kwargs)
+
+        pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3., components=["transformer"])
+        pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1., components=["transformer_2"])
+        pipeline.unload_lora_weights()
+
+        with capture_component_call(pipeline, 'transformer') as call:
+            pipeline(*args, **kwargs)
 
         dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
         dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
 
-        print("[compile_transformer] Quantizing transformers with Float8DynamicActivationFloat8WeightConfig")
         quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
         quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
-
+
         hidden_states: torch.Tensor = call.kwargs['hidden_states']
         hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
-
         if hidden_states.shape[-1] > hidden_states.shape[-2]:
             hidden_states_landscape = hidden_states
             hidden_states_portrait = hidden_states_transposed
@@ -68,34 +81,34 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
             hidden_states_landscape = hidden_states_transposed
             hidden_states_portrait = hidden_states
 
-        print("[compile_transformer] Exporting transformer landscape model")
         exported_landscape_1 = torch.export.export(
             mod=pipeline.transformer,
             args=call.args,
-            kwargs={**call.kwargs, 'hidden_states': hidden_states_landscape},
+            kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
             dynamic_shapes=dynamic_shapes,
         )
-        torch.cuda.synchronize()
-
-        print("[compile_transformer] Exporting transformer portrait model")
+
         exported_portrait_2 = torch.export.export(
             mod=pipeline.transformer_2,
             args=call.args,
-            kwargs={**call.kwargs, 'hidden_states': hidden_states_portrait},
+            kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
             dynamic_shapes=dynamic_shapes,
         )
-        torch.cuda.synchronize()
 
-        print("[compile_transformer] Compiling models with AoT compilation")
         compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
         compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)
 
         compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
         compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)
 
-        print("[compile_transformer] Compilation done")
-        return compiled_landscape_1, compiled_landscape_2, compiled_portrait_1, compiled_portrait_2
+        return (
+            compiled_landscape_1,
+            compiled_landscape_2,
+            compiled_portrait_1,
+            compiled_portrait_2,
+        )
 
+    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
     cl1, cl2, cp1, cp2 = compile_transformer()
 
     def combined_transformer_1(*args, **kwargs):
@@ -114,7 +127,6 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
 
     pipeline.transformer.forward = combined_transformer_1
     drain_module_parameters(pipeline.transformer)
-    pipeline.transformer_2.forward = combined_transformer_2
-    drain_module_parameters(pipeline.transformer_2)
 
-    print("[optimize_pipeline_] Optimization complete")
+    pipeline.transformer_2.forward = combined_transformer_2
+    drain_module_parameters(pipeline.transformer_2)
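
Usage note (not part of this commit): optimize_pipeline_ is designed to be called once at startup with the same arguments as a later generation call, so the dummy forward pass captured by capture_component_call records representative tensor shapes for torch.export. A minimal sketch, assuming the diffusers WanPipeline class for the two-transformer Wan 2.2 T2V model; the checkpoint id and call arguments below are illustrative, not taken from this commit:

import torch
from diffusers import WanPipeline  # assumed pipeline class exposing transformer / transformer_2

from optimization import optimize_pipeline_

# Illustrative checkpoint id for a two-stage Wan 2.2 text-to-video pipeline.
pipeline = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
    torch_dtype=torch.bfloat16,
)

# Warm-up/compile call: pass the same kwargs a real generation will use.
# The latent frame axis (dim 2 of hidden_states) stays dynamic via
# TRANSFORMER_NUM_FRAMES_DIM (min=3, max=21), so num_frames may vary later.
optimize_pipeline_(
    pipeline,
    prompt="a placeholder prompt",  # illustrative arguments
    num_frames=81,
    height=480,
    width=832,
)

Design note: only two exports are AoT-compiled (pipeline.transformer against the landscape layout, pipeline.transformer_2 against portrait); the other two variants are assembled by pairing each compiled archive with the other model's weights via ZeroGPUCompiledModel, which presumably relies on the two transformers sharing an identical architecture. The unchanged combined_transformer_* wrappers (bodies not shown in this diff) presumably dispatch to the landscape or portrait compiled variant based on the orientation of the incoming hidden_states.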