rahul7star commited on
Commit
d83fb5a
·
verified ·
1 Parent(s): 3faf8ae

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +19 -0
optimization.py CHANGED
@@ -1,6 +1,25 @@
1
  import torch
2
  import torchao
3
  from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
6
  print("[optimize_pipeline_] Starting pipeline optimization")
 
1
  import torch
2
  import torchao
3
  from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST
4
+ P = ParamSpec('P')
5
+
6
+
7
+ TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
8
+
9
+ TRANSFORMER_DYNAMIC_SHAPES = {
10
+ 'hidden_states': {
11
+ 2: TRANSFORMER_NUM_FRAMES_DIM,
12
+ },
13
+ }
14
+
15
+ INDUCTOR_CONFIGS = {
16
+ 'conv_1x1_as_mm': True,
17
+ 'epilogue_fusion': False,
18
+ 'coordinate_descent_tuning': True,
19
+ 'coordinate_descent_check_all_directions': True,
20
+ 'max_autotune': True,
21
+ 'triton.cudagraphs': True,
22
+ }
23
 
24
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
25
  print("[optimize_pipeline_] Starting pipeline optimization")