rahul7star commited on
Commit
6e8eb03
·
verified ·
1 Parent(s): 09a6fb7

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +22 -13
optimization.py CHANGED
@@ -36,10 +36,12 @@ INDUCTOR_CONFIGS = {
36
  'triton.cudagraphs': True,
37
  }
38
 
 
 
39
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
40
  print("[optimize_pipeline_] Starting pipeline optimization")
41
 
42
- # Text encoder: move to CPU first, then quantize+compile to avoid early CUDA init
43
  pipeline.text_encoder = pipeline.text_encoder.cpu()
44
  pipeline.text_encoder = torchao.autoquant(
45
  torch.compile(pipeline.text_encoder, mode='max-autotune'),
@@ -49,6 +51,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
49
 
50
  @spaces.GPU(duration=1500)
51
  def compile_transformer():
 
52
  print("[compile_transformer] Loading LoRA weights")
53
  pipeline.load_lora_weights(
54
  "DeepBeepMeep/Wan2.2",
@@ -63,23 +66,30 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
63
  )
64
  pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1.0, 1.0])
65
 
 
66
  print("[compile_transformer] Fusing LoRA weights")
67
  pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3.0, components=["transformer"])
68
  pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1.0, components=["transformer_2"])
69
  pipeline.unload_lora_weights()
70
 
71
- print("[compile_transformer] Running dummy forward pass to capture shapes")
 
72
  with torch.inference_mode():
73
  with capture_component_call(pipeline, 'transformer') as call:
74
  pipeline(*args, **kwargs)
75
 
76
- dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
77
- dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
78
 
79
- # Now that we're inside GPU context, move and compile transformers
80
  pipeline.transformer = pipeline.transformer.to("cuda")
81
  pipeline.transformer_2 = pipeline.transformer_2.to("cuda")
82
 
 
 
 
 
 
83
  compiled_transformer = torchao.autoquant(
84
  torch.compile(pipeline.transformer, mode='max-autotune'),
85
  qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST
@@ -89,28 +99,26 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
89
  qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST
90
  )
91
 
92
- # Patch forward with landscape/portrait logic
93
- hidden_states: torch.Tensor = call.kwargs['hidden_states']
94
- hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
95
-
96
  def combined_transformer_1(*a, **k):
97
  if k['hidden_states'].shape[-1] > k['hidden_states'].shape[-2]:
98
  return compiled_transformer(*a, **k)
99
- k_mod = k.copy()
100
  k_mod['hidden_states'] = hidden_states_transposed
101
  return compiled_transformer(*a, **k_mod)
102
 
103
  def combined_transformer_2(*a, **k):
104
  if k['hidden_states'].shape[-1] > k['hidden_states'].shape[-2]:
105
  return compiled_transformer_2(*a, **k)
106
- k_mod = k.copy()
107
  k_mod['hidden_states'] = hidden_states_transposed
108
  return compiled_transformer_2(*a, **k_mod)
109
 
110
  pipeline.transformer.forward = combined_transformer_1
111
- drain_module_parameters(pipeline.transformer)
112
-
113
  pipeline.transformer_2.forward = combined_transformer_2
 
 
 
114
  drain_module_parameters(pipeline.transformer_2)
115
 
116
  print("[compile_transformer] Transformers autoquantized, compiled, and patched")
@@ -118,3 +126,4 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
118
 
119
  cl1, cl2 = compile_transformer()
120
  print("[optimize_pipeline_] Optimization complete")
 
 
36
  'triton.cudagraphs': True,
37
  }
38
 
39
+ from torchao.quantization import DEFAULT_INT4_AUTOQUANT_CLASS_LIST
40
+
41
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
42
  print("[optimize_pipeline_] Starting pipeline optimization")
43
 
44
+ # --- TEXT ENCODER ---
45
  pipeline.text_encoder = pipeline.text_encoder.cpu()
46
  pipeline.text_encoder = torchao.autoquant(
47
  torch.compile(pipeline.text_encoder, mode='max-autotune'),
 
51
 
52
  @spaces.GPU(duration=1500)
53
  def compile_transformer():
54
+ # --- LOAD LORAS ---
55
  print("[compile_transformer] Loading LoRA weights")
56
  pipeline.load_lora_weights(
57
  "DeepBeepMeep/Wan2.2",
 
66
  )
67
  pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1.0, 1.0])
68
 
69
+ # --- FUSE & UNLOAD ---
70
  print("[compile_transformer] Fusing LoRA weights")
71
  pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3.0, components=["transformer"])
72
  pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1.0, components=["transformer_2"])
73
  pipeline.unload_lora_weights()
74
 
75
+ # --- DUMMY FORWARD ---
76
+ print("[compile_transformer] Capturing shapes")
77
  with torch.inference_mode():
78
  with capture_component_call(pipeline, 'transformer') as call:
79
  pipeline(*args, **kwargs)
80
 
81
+ hidden_states: torch.Tensor = call.kwargs['hidden_states']
82
+ hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
83
 
84
+ # --- MOVE TO CUDA BEFORE AUTOQUANT ---
85
  pipeline.transformer = pipeline.transformer.to("cuda")
86
  pipeline.transformer_2 = pipeline.transformer_2.to("cuda")
87
 
88
+ # Sanity: Ensure parameters exist before quantization
89
+ assert any(p.numel() > 0 for p in pipeline.transformer.parameters()), "Transformer has no params!"
90
+ assert any(p.numel() > 0 for p in pipeline.transformer_2.parameters()), "Transformer_2 has no params!"
91
+
92
+ # --- AUTOQUANT + COMPILE ---
93
  compiled_transformer = torchao.autoquant(
94
  torch.compile(pipeline.transformer, mode='max-autotune'),
95
  qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST
 
99
  qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST
100
  )
101
 
102
+ # --- PATCH FOR LANDSCAPE/PORTRAIT ---
 
 
 
103
  def combined_transformer_1(*a, **k):
104
  if k['hidden_states'].shape[-1] > k['hidden_states'].shape[-2]:
105
  return compiled_transformer(*a, **k)
106
+ k_mod = dict(k)
107
  k_mod['hidden_states'] = hidden_states_transposed
108
  return compiled_transformer(*a, **k_mod)
109
 
110
  def combined_transformer_2(*a, **k):
111
  if k['hidden_states'].shape[-1] > k['hidden_states'].shape[-2]:
112
  return compiled_transformer_2(*a, **k)
113
+ k_mod = dict(k)
114
  k_mod['hidden_states'] = hidden_states_transposed
115
  return compiled_transformer_2(*a, **k_mod)
116
 
117
  pipeline.transformer.forward = combined_transformer_1
 
 
118
  pipeline.transformer_2.forward = combined_transformer_2
119
+
120
+ # --- NOW drain parameters to save VRAM ---
121
+ drain_module_parameters(pipeline.transformer)
122
  drain_module_parameters(pipeline.transformer_2)
123
 
124
  print("[compile_transformer] Transformers autoquantized, compiled, and patched")
 
126
 
127
  cl1, cl2 = compile_transformer()
128
  print("[optimize_pipeline_] Optimization complete")
129
+