Spaces:
Running
Running
feat(modeling): simplify abstract_init
Browse files
src/dalle_mini/model/modeling.py
CHANGED
|
@@ -334,7 +334,9 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
| 334 |
|
| 335 |
# init weights on CPU
|
| 336 |
if load_on_cpu:
|
| 337 |
-
init_fn = jax.jit(
|
|
|
|
|
|
|
| 338 |
else:
|
| 339 |
init_fn = self.init_weights
|
| 340 |
|
|
@@ -343,10 +345,11 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
| 343 |
# init the model weights only abstractly, eval_shape will return a pytree
|
| 344 |
# with the structure as weights but without any actual values, this will just contain
|
| 345 |
# the shape information. Weights need to be loaded later.
|
| 346 |
-
|
| 347 |
-
|
|
|
|
| 348 |
else:
|
| 349 |
-
random_params = init_fn(self.key, input_shape)
|
| 350 |
|
| 351 |
# save required_params as set
|
| 352 |
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|
|
|
|
| 334 |
|
| 335 |
# init weights on CPU
|
| 336 |
if load_on_cpu:
|
| 337 |
+
init_fn = jax.jit(
|
| 338 |
+
self.init_weights, static_argnames="input_shape", backend="cpu"
|
| 339 |
+
)
|
| 340 |
else:
|
| 341 |
init_fn = self.init_weights
|
| 342 |
|
|
|
|
| 345 |
# init the model weights only abstractly, eval_shape will return a pytree
|
| 346 |
# with the structure as weights but without any actual values, this will just contain
|
| 347 |
# the shape information. Weights need to be loaded later.
|
| 348 |
+
random_params = jax.eval_shape(
|
| 349 |
+
init_fn, rng=self.key, input_shape=input_shape
|
| 350 |
+
)
|
| 351 |
else:
|
| 352 |
+
random_params = init_fn(rng=self.key, input_shape=input_shape)
|
| 353 |
|
| 354 |
# save required_params as set
|
| 355 |
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|