Spaces:
Running
Running
fix: load from checkpoint
Browse files
src/dalle_mini/model/modeling.py
CHANGED
|
@@ -334,22 +334,19 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
| 334 |
|
| 335 |
# init weights on CPU
|
| 336 |
if load_on_cpu:
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
)
|
| 340 |
else:
|
| 341 |
-
init_fn = self.
|
| 342 |
|
| 343 |
# randomly initialized parameters
|
|
|
|
| 344 |
if abstract_init:
|
| 345 |
-
#
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
random_params = jax.eval_shape(
|
| 349 |
-
init_fn, rng=self.key, input_shape=input_shape
|
| 350 |
-
)
|
| 351 |
else:
|
| 352 |
-
random_params = init_fn(
|
| 353 |
|
| 354 |
# save required_params as set
|
| 355 |
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|
|
|
|
| 334 |
|
| 335 |
# init weights on CPU
|
| 336 |
if load_on_cpu:
|
| 337 |
+
# init weights on CPU
|
| 338 |
+
init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu")
|
|
|
|
| 339 |
else:
|
| 340 |
+
init_fn = self.init_weights
|
| 341 |
|
| 342 |
# randomly initialized parameters
|
| 343 |
+
random_params = self.init_weights(self.key, input_shape)
|
| 344 |
if abstract_init:
|
| 345 |
+
# only set shape and dtype, load parameters separately
|
| 346 |
+
init_fn = partial(init_fn, input_shape=input_shape)
|
| 347 |
+
random_params = jax.eval_shape(init_fn, self.key)
|
|
|
|
|
|
|
|
|
|
| 348 |
else:
|
| 349 |
+
random_params = init_fn(self.key, input_shape)
|
| 350 |
|
| 351 |
# save required_params as set
|
| 352 |
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|