Yuanshi committed (verified)
Commit 8c2a078 · 1 Parent(s): 686f79e

Update src/transformer.py

Files changed (1)
  1. src/transformer.py +0 -7
src/transformer.py CHANGED
@@ -7,7 +7,6 @@ from diffusers.models.transformers.transformer_flux import (
     FluxTransformer2DModel,
     Transformer2DModelOutput,
     USE_PEFT_BACKEND,
-    is_torch_version,
     scale_lora_layers,
     unscale_lora_layers,
     logger,
@@ -155,9 +154,6 @@ def tranformer_forward(

                return custom_forward

-            ckpt_kwargs: Dict[str, Any] = (
-                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-            )
             encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                 create_custom_forward(block),
                 hidden_states,
@@ -204,9 +200,6 @@ def tranformer_forward(

                return custom_forward

-            ckpt_kwargs: Dict[str, Any] = (
-                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-            )
             hidden_states = torch.utils.checkpoint.checkpoint(
                 create_custom_forward(block),
                 hidden_states,
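For context, the deleted lines built a version-gated kwargs dict so that gradient checkpointing used the non-reentrant implementation (use_reentrant=False) on PyTorch >= 1.11, where that option exists. Below is a minimal sketch of the removed pattern, not the repo's actual code: it assumes is_torch_version is importable from diffusers.utils, and run_block is a hypothetical stand-in for the create_custom_forward(block) wrapper seen in the diff.

    from typing import Any, Dict

    import torch
    import torch.utils.checkpoint
    from diffusers.utils import is_torch_version  # same helper the removed import provided


    def run_block(hidden_states: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for the transformer block wrapped by create_custom_forward.
        return hidden_states * 2.0


    # Pattern removed by this commit: only pass use_reentrant=False when the
    # running PyTorch supports the non-reentrant checkpoint path (>= 1.11).
    ckpt_kwargs: Dict[str, Any] = (
        {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
    )

    x = torch.randn(2, 4, requires_grad=True)
    out = torch.utils.checkpoint.checkpoint(run_block, x, **ckpt_kwargs)
    out.sum().backward()  # activations of run_block are recomputed during backward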
 