Spaces:
Running
Running
add property to get num params
Browse files
dalle_mini/modeling_bart_flax.py
CHANGED
|
@@ -24,6 +24,7 @@ import flax.linen as nn
|
|
| 24 |
import jax
|
| 25 |
import jax.numpy as jnp
|
| 26 |
from flax.core.frozen_dict import FrozenDict, unfreeze
|
|
|
|
| 27 |
from flax.linen import combine_masks, make_causal_mask
|
| 28 |
from flax.linen.attention import dot_product_attention_weights
|
| 29 |
from jax import lax
|
|
@@ -622,6 +623,11 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
|
|
| 622 |
module = self.module_class(config=config, dtype=dtype)
|
| 623 |
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, **kwargs)
|
| 624 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 625 |
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
| 626 |
# init input tensors
|
| 627 |
input_ids = jnp.zeros(input_shape, dtype="i4")
|
|
|
|
| 24 |
import jax
|
| 25 |
import jax.numpy as jnp
|
| 26 |
from flax.core.frozen_dict import FrozenDict, unfreeze
|
| 27 |
+
from flax.traverse_util import flatten_dict
|
| 28 |
from flax.linen import combine_masks, make_causal_mask
|
| 29 |
from flax.linen.attention import dot_product_attention_weights
|
| 30 |
from jax import lax
|
|
|
|
| 623 |
module = self.module_class(config=config, dtype=dtype)
|
| 624 |
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, **kwargs)
|
| 625 |
|
| 626 |
+
@property
def num_params(self):
    """Return the total number of scalar entries held in ``self.params``.

    Every leaf of the parameter tree contributes its ``size`` (the number
    of elements in that leaf) to the total.

    Returns:
        The sum of ``leaf.size`` over all leaves of ``self.params``;
        ``0`` when the tree has no leaves.
    """
    # Walk the pytree leaves directly instead of unfreezing, flattening and
    # mapping first. This also avoids the top-level ``jax.tree_map`` alias,
    # which recent JAX releases deprecated and then removed.
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(self.params))
|
| 630 |
+
|
| 631 |
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
| 632 |
# init input tensors
|
| 633 |
input_ids = jnp.zeros(input_shape, dtype="i4")
|