linoyts HF Staff commited on
Commit
aa93fb5
·
verified ·
1 Parent(s): 74ee40c

Create optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +67 -0
optimization.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ from typing import Any
5
+ from typing import Callable
6
+ from typing import ParamSpec
7
+ import spaces
8
+ import torch
9
+ from torch.utils._pytree import tree_map
10
+
11
+
12
+ P = ParamSpec('P')
13
+
14
+
15
+ TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
16
+ TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')
17
+
18
+ TRANSFORMER_DYNAMIC_SHAPES = {
19
+ 'hidden_states': {
20
+ 1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
21
+ },
22
+ 'encoder_hidden_states': {
23
+ 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
24
+ },
25
+ 'encoder_hidden_states_mask': {
26
+ 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
27
+ },
28
+ 'image_rotary_emb': ({
29
+ 0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
30
+ }, {
31
+ 0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
32
+ }),
33
+ }
34
+
35
+
36
+ INDUCTOR_CONFIGS = {
37
+ 'conv_1x1_as_mm': True,
38
+ 'epilogue_fusion': False,
39
+ 'coordinate_descent_tuning': True,
40
+ 'coordinate_descent_check_all_directions': True,
41
+ 'max_autotune': True,
42
+ 'triton.cudagraphs': True,
43
+ }
44
+
45
+
46
+ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
47
+
48
+ @spaces.GPU(duration=1500)
49
+ def compile_transformer():
50
+
51
+ with spaces.aoti_capture(pipeline.transformer) as call:
52
+ pipeline(*args, **kwargs)
53
+
54
+ dynamic_shapes = tree_map(lambda t: None, call.kwargs)
55
+ dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
56
+
57
+
58
+ exported = torch.export.export(
59
+ mod=pipeline.transformer,
60
+ args=call.args,
61
+ kwargs=call.kwargs,
62
+ dynamic_shapes=dynamic_shapes,
63
+ )
64
+
65
+ return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
66
+
67
+ spaces.aoti_apply(compile_transformer(), pipeline.transformer)