artificialguybr committed on
Commit
7d156bd
·
verified ·
1 Parent(s): 88f88e9

Create optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +67 -0
optimization.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ from typing import Callable
3
+ from typing import ParamSpec
4
+ from torchao.quantization import quantize_
5
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
6
+ import spaces
7
+ import torch
8
+ from torch.utils._pytree import tree_map
9
+
10
+
11
+ P = ParamSpec('P')
12
+
13
+
14
+ TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
15
+ TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')
16
+
17
+ TRANSFORMER_DYNAMIC_SHAPES = {
18
+ 'hidden_states': {
19
+ 1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
20
+ },
21
+ 'encoder_hidden_states': {
22
+ 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
23
+ },
24
+ 'encoder_hidden_states_mask': {
25
+ 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
26
+ },
27
+ 'image_rotary_emb': ({
28
+ 0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
29
+ }, {
30
+ 0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
31
+ }),
32
+ }
33
+
34
+
35
# Inductor option overrides passed to ``spaces.aoti_compile`` when the
# exported transformer is compiled. Every flag is enabled except
# ``epilogue_fusion``.
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}
43
+
44
+
45
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Export and AOT-compile ``pipeline.transformer``, then apply the result.

    The pipeline is invoked once with ``*args`` / ``**kwargs`` inside
    ``spaces.aoti_capture(pipeline.transformer)``; the ``args`` and ``kwargs``
    exposed by the capture object are then used as example inputs for
    ``torch.export.export``. The exported program is compiled with
    ``spaces.aoti_compile`` using ``INDUCTOR_CONFIGS`` and the compiled result
    is passed to ``spaces.aoti_apply`` together with ``pipeline.transformer``.

    Args:
        pipeline: Callable pipeline object exposing a ``transformer`` attribute.
        *args: Positional arguments forwarded to ``pipeline`` for the capture run.
        **kwargs: Keyword arguments forwarded to ``pipeline`` for the capture run.

    Returns:
        None.
    """

    # NOTE(review): ``duration`` is presumably the GPU time budget in seconds
    # for this call -- confirm against the ``spaces`` documentation.
    @spaces.GPU(duration=1500)
    def compile_transformer():
        # Run the pipeline once so the transformer's call arguments get recorded.
        with spaces.aoti_capture(pipeline.transformer) as captured:
            pipeline(*args, **kwargs)

        # Start with ``None`` for every captured keyword-argument leaf, then
        # overwrite the entries that have an explicit dynamic-shape spec.
        shape_spec = tree_map(lambda _leaf: None, captured.kwargs)
        shape_spec.update(TRANSFORMER_DYNAMIC_SHAPES)

        # FP8 quantization is currently disabled:
        # quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())

        exported_program = torch.export.export(
            mod=pipeline.transformer,
            args=captured.args,
            kwargs=captured.kwargs,
            dynamic_shapes=shape_spec,
        )
        return spaces.aoti_compile(exported_program, INDUCTOR_CONFIGS)

    compiled_transformer = compile_transformer()
    spaces.aoti_apply(compiled_transformer, pipeline.transformer)