rahul7star committed on
Commit 3faf8ae · verified · 1 Parent(s): 58fd9c6

Update optimization.py

Files changed (1)
  1. optimization.py +42 -67
optimization.py CHANGED
@@ -1,29 +1,16 @@
-from typing import Any, Callable, ParamSpec
-import spaces
 import torch
-from torch.utils._pytree import tree_map_only
-from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
-from optimization_utils import capture_component_call, aoti_compile, ZeroGPUCompiledModel, drain_module_parameters
-
-P = ParamSpec('P')
-
-TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
-TRANSFORMER_DYNAMIC_SHAPES = {'hidden_states': {2: TRANSFORMER_NUM_FRAMES_DIM}}
-
-INDUCTOR_CONFIGS = {
-    'conv_1x1_as_mm': True,
-    'epilogue_fusion': False,
-    'coordinate_descent_tuning': True,
-    'coordinate_descent_check_all_directions': True,
-    'max_autotune': True,
-    'triton.cudagraphs': True,
-}
+import torchao
+from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST
 
 def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
     print("[optimize_pipeline_] Starting pipeline optimization")
 
-    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
-    print("[optimize_pipeline_] Text encoder quantized")
+    # Quantize and compile text encoder first (weight-only int8 quantization can be replaced by autoquant if preferred)
+    pipeline.text_encoder = torchao.autoquant(
+        torch.compile(pipeline.text_encoder, mode='max-autotune'),
+        qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST  # or remove for default quant
+    )
+    print("[optimize_pipeline_] Text encoder autoquantized and compiled")
 
     @spaces.GPU(duration=1500)
     def compile_transformer():
@@ -54,9 +41,16 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
         dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
         dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
 
-        print("[compile_transformer] Quantizing transformers with Float8DynamicActivationFloat8WeightConfig")
-        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
-        quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
+        # Use autoquant + torch.compile on transformers
+        print("[compile_transformer] Autoquantizing and compiling transformer")
+        compiled_transformer = torchao.autoquant(
+            torch.compile(pipeline.transformer, mode='max-autotune'),
+            qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
+        )
+        compiled_transformer_2 = torchao.autoquant(
+            torch.compile(pipeline.transformer_2, mode='max-autotune'),
+            qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
+        )
 
         hidden_states: torch.Tensor = call.kwargs['hidden_states']
         hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
@@ -68,53 +62,34 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
         hidden_states_landscape = hidden_states_transposed
         hidden_states_portrait = hidden_states
 
-        print("[compile_transformer] Exporting transformer landscape model")
-        exported_landscape_1 = torch.export.export(
-            mod=pipeline.transformer,
-            args=call.args,
-            kwargs={**call.kwargs, 'hidden_states': hidden_states_landscape},
-            dynamic_shapes=dynamic_shapes,
-        )
-        torch.cuda.synchronize()
-
-        print("[compile_transformer] Exporting transformer portrait model")
-        exported_portrait_2 = torch.export.export(
-            mod=pipeline.transformer_2,
-            args=call.args,
-            kwargs={**call.kwargs, 'hidden_states': hidden_states_portrait},
-            dynamic_shapes=dynamic_shapes,
-        )
-        torch.cuda.synchronize()
+        # Replace forward with quantized & compiled versions, wrapped for shape dispatch
+        def combined_transformer_1(*a, **k):
+            if k['hidden_states'].shape[-1] > k['hidden_states'].shape[-2]:
+                return compiled_transformer(*a, **k)
+            else:
+                # Swap hidden states for portrait? Use transpose if needed.
+                k_mod = k.copy()
+                k_mod['hidden_states'] = hidden_states_portrait
+                return compiled_transformer(*a, **k_mod)
 
-        print("[compile_transformer] Compiling models with AoT compilation")
-        compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
-        compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)
+        def combined_transformer_2(*a, **k):
+            if k['hidden_states'].shape[-1] > k['hidden_states'].shape[-2]:
+                return compiled_transformer_2(*a, **k)
+            else:
+                k_mod = k.copy()
+                k_mod['hidden_states'] = hidden_states_portrait
+                return compiled_transformer_2(*a, **k_mod)
 
-        compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
-        compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)
+        pipeline.transformer.forward = combined_transformer_1
+        drain_module_parameters(pipeline.transformer)
 
-        print("[compile_transformer] Compilation done")
-        return compiled_landscape_1, compiled_landscape_2, compiled_portrait_1, compiled_portrait_2
+        pipeline.transformer_2.forward = combined_transformer_2
+        drain_module_parameters(pipeline.transformer_2)
 
-    cl1, cl2, cp1, cp2 = compile_transformer()
-
-    def combined_transformer_1(*args, **kwargs):
-        hidden_states: torch.Tensor = kwargs['hidden_states']
-        if hidden_states.shape[-1] > hidden_states.shape[-2]:
-            return cl1(*args, **kwargs)
-        else:
-            return cp1(*args, **kwargs)
-
-    def combined_transformer_2(*args, **kwargs):
-        hidden_states: torch.Tensor = kwargs['hidden_states']
-        if hidden_states.shape[-1] > hidden_states.shape[-2]:
-            return cl2(*args, **kwargs)
-        else:
-            return cp2(*args, **kwargs)
+        print("[compile_transformer] Transformers autoquantized, compiled, and patched")
 
-    pipeline.transformer.forward = combined_transformer_1
-    drain_module_parameters(pipeline.transformer)
-    pipeline.transformer_2.forward = combined_transformer_2
-    drain_module_parameters(pipeline.transformer_2)
+        # Return compiled models for reference if needed
+        return compiled_transformer, compiled_transformer_2
 
+    cl1, cl2 = compile_transformer()
     print("[optimize_pipeline_] Optimization complete")