Spaces:
Running
Running
add gradient checkpointing
Browse files
dalle_mini/modeling_bart_flax.py
CHANGED
|
@@ -252,7 +252,8 @@ class FlaxBartEncoderLayer(nn.Module):
|
|
| 252 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 253 |
)
|
| 254 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
| 255 |
-
|
|
|
|
| 256 |
def __call__(
|
| 257 |
self,
|
| 258 |
hidden_states: jnp.ndarray,
|
|
@@ -343,7 +344,8 @@ class FlaxBartDecoderLayer(nn.Module):
|
|
| 343 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 344 |
)
|
| 345 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
| 346 |
-
|
|
|
|
| 347 |
def __call__(
|
| 348 |
self,
|
| 349 |
hidden_states: jnp.ndarray,
|
|
|
|
| 252 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 253 |
)
|
| 254 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
| 255 |
+
|
| 256 |
+
@nn.remat
|
| 257 |
def __call__(
|
| 258 |
self,
|
| 259 |
hidden_states: jnp.ndarray,
|
|
|
|
| 344 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
| 345 |
)
|
| 346 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
| 347 |
+
|
| 348 |
+
@nn.remat
|
| 349 |
def __call__(
|
| 350 |
self,
|
| 351 |
hidden_states: jnp.ndarray,
|