rahul7star committed · verified
Commit 8750a83 · 1 Parent(s): 6e8eb03

Update optimization.py

Files changed (1):
  1. optimization.py +63 -72
optimization.py CHANGED
@@ -1,31 +1,14 @@
-import torch
-import torchao
-from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST
-from typing import Any
-from typing import Callable
-from typing import ParamSpec
-
+from typing import Any, Callable, ParamSpec
 import spaces
 import torch
 from torch.utils._pytree import tree_map_only
-from torchao.quantization import quantize_
-from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
-from torchao.quantization import Int8WeightOnlyConfig
-
-from optimization_utils import capture_component_call
-from optimization_utils import aoti_compile
-from optimization_utils import ZeroGPUCompiledModel
-from optimization_utils import drain_module_parameters
-P = ParamSpec('P')
+from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
+from optimization_utils import capture_component_call, aoti_compile, ZeroGPUCompiledModel, drain_module_parameters
+
+P = ParamSpec('P')
 
-
 TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
-
-TRANSFORMER_DYNAMIC_SHAPES = {
-    'hidden_states': {
-        2: TRANSFORMER_NUM_FRAMES_DIM,
-    },
-}
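+# Dim 2 of hidden_states is the frame axis; exporting with this Dim keeps the
+# compiled graph valid for any frame count in [3, 21] without recompilation.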
+TRANSFORMER_DYNAMIC_SHAPES = {'hidden_states': {2: TRANSFORMER_NUM_FRAMES_DIM}}
 
 INDUCTOR_CONFIGS = {
     'conv_1x1_as_mm': True,
@@ -36,22 +19,14 @@ INDUCTOR_CONFIGS = {
     'triton.cudagraphs': True,
 }
 
-from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST
-
 def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
     print("[optimize_pipeline_] Starting pipeline optimization")
 
-    # --- TEXT ENCODER ---
-    pipeline.text_encoder = pipeline.text_encoder.cpu()
-    pipeline.text_encoder = torchao.autoquant(
-        torch.compile(pipeline.text_encoder, mode='max-autotune'),
-        qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST
-    ).to("cuda")
-    print("[optimize_pipeline_] Text encoder autoquantized and compiled")
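+    # Int8 weight-only quantization roughly halves the text encoder's weight
+    # memory versus 16-bit weights; activations stay in the original dtype.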
+    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
+    print("[optimize_pipeline_] Text encoder quantized")
 
     @spaces.GPU(duration=1500)
     def compile_transformer():
-        # --- LOAD LORAS ---
         print("[compile_transformer] Loading LoRA weights")
         pipeline.load_lora_weights(
             "DeepBeepMeep/Wan2.2",
@@ -66,64 +41,80 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
         )
         pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1.0, 1.0])
 
-        # --- FUSE & UNLOAD ---
         print("[compile_transformer] Fusing LoRA weights")
         pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3.0, components=["transformer"])
         pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1.0, components=["transformer_2"])
         pipeline.unload_lora_weights()
 
-        # --- DUMMY FORWARD ---
-        print("[compile_transformer] Capturing shapes")
+        print("[compile_transformer] Running dummy forward pass to capture component call")
         with torch.inference_mode():
             with capture_component_call(pipeline, 'transformer') as call:
                 pipeline(*args, **kwargs)
 
-        hidden_states: torch.Tensor = call.kwargs['hidden_states']
-        hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
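+        # Mirror call.kwargs as a pytree of None (static) markers, then flag
+        # only the num_frames dim of hidden_states as dynamic for export.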
+        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
+        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
 
-        # --- MOVE TO CUDA BEFORE AUTOQUANT ---
-        pipeline.transformer = pipeline.transformer.to("cuda")
-        pipeline.transformer_2 = pipeline.transformer_2.to("cuda")
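+        # Quantize in place before export so the traced graphs contain the
+        # float8 dynamic-activation/float8-weight kernels.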
+ print("[compile_transformer] Quantizing transformers with Float8DynamicActivationFloat8WeightConfig")
58
+ quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
59
+ quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
60
 
61
+ hidden_states: torch.Tensor = call.kwargs['hidden_states']
62
+ hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
 
63
 
64
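+        # Height and width are static in the exported graphs, so one landscape
+        # and one portrait variant together cover both output orientations.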
+        if hidden_states.shape[-1] > hidden_states.shape[-2]:
+            hidden_states_landscape = hidden_states
+            hidden_states_portrait = hidden_states_transposed
+        else:
+            hidden_states_landscape = hidden_states_transposed
+            hidden_states_portrait = hidden_states
+
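+        # torch.export traces the quantized transformer with example inputs;
+        # only the frame dim declared in dynamic_shapes stays symbolic.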
+ print("[compile_transformer] Exporting transformer landscape model")
72
+ exported_landscape_1 = torch.export.export(
73
+ mod=pipeline.transformer,
74
+ args=call.args,
75
+ kwargs={**call.kwargs, 'hidden_states': hidden_states_landscape},
76
+ dynamic_shapes=dynamic_shapes,
77
  )
78
+ torch.cuda.synchronize()
79
+
80
+ print("[compile_transformer] Exporting transformer portrait model")
81
+ exported_portrait_2 = torch.export.export(
82
+ mod=pipeline.transformer_2,
83
+ args=call.args,
84
+ kwargs={**call.kwargs, 'hidden_states': hidden_states_portrait},
85
+ dynamic_shapes=dynamic_shapes,
86
  )
87
+ torch.cuda.synchronize()
88
+
89
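+        # Ahead-of-time Inductor compilation: each exported graph becomes a
+        # compiled archive plus a detached set of weights.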
+ print("[compile_transformer] Compiling models with AoT compilation")
90
+ compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
91
+ compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)
92
 
93
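+        # Assuming both transformers share one architecture, each compiled
+        # archive can be re-paired with the other model's weights, giving all
+        # four model/orientation variants from just two compilations.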
+        compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
+        compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)
 
-        def combined_transformer_2(*a, **k):
-            if k['hidden_states'].shape[-1] > k['hidden_states'].shape[-2]:
-                return compiled_transformer_2(*a, **k)
-            k_mod = dict(k)
-            k_mod['hidden_states'] = hidden_states_transposed
-            return compiled_transformer_2(*a, **k_mod)
+        print("[compile_transformer] Compilation done")
+        return compiled_landscape_1, compiled_landscape_2, compiled_portrait_1, compiled_portrait_2
 
-        pipeline.transformer.forward = combined_transformer_1
-        pipeline.transformer_2.forward = combined_transformer_2
+    cl1, cl2, cp1, cp2 = compile_transformer()
 
-        # --- NOW drain parameters to save VRAM ---
-        drain_module_parameters(pipeline.transformer)
-        drain_module_parameters(pipeline.transformer_2)
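+    # Runtime dispatch: route wider-than-tall latents to the landscape build,
+    # everything else to the portrait build.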
+    def combined_transformer_1(*args, **kwargs):
+        hidden_states: torch.Tensor = kwargs['hidden_states']
+        if hidden_states.shape[-1] > hidden_states.shape[-2]:
+            return cl1(*args, **kwargs)
+        else:
+            return cp1(*args, **kwargs)
 
-        print("[compile_transformer] Transformers autoquantized, compiled, and patched")
-        return compiled_transformer, compiled_transformer_2
+    def combined_transformer_2(*args, **kwargs):
+        hidden_states: torch.Tensor = kwargs['hidden_states']
+        if hidden_states.shape[-1] > hidden_states.shape[-2]:
+            return cl2(*args, **kwargs)
+        else:
+            return cp2(*args, **kwargs)
 
-    cl1, cl2 = compile_transformer()
-    print("[optimize_pipeline_] Optimization complete")
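+    # With forward rebound to the compiled models, the eager parameters can be
+    # drained to reclaim VRAM; the compiled variants carry their own weights.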
+    pipeline.transformer.forward = combined_transformer_1
+    drain_module_parameters(pipeline.transformer)
+    pipeline.transformer_2.forward = combined_transformer_2
+    drain_module_parameters(pipeline.transformer_2)
 
+    print("[optimize_pipeline_] Optimization complete")