benjamin committed on
Commit 0807ed1 · verified · Parent(s): d18f180

Upload FlaxTPUGemma3ForCausalLM

README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
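+ Until the card is filled in, here is a minimal, untested sketch of loading this checkpoint from a local clone of the repository (assumes `git lfs pull` has been run and `jax`/`flax` are installed; no tokenizer ships with this commit, so raw token ids are used):
+
+ ```python
+ import numpy as np
+ import jax
+ import jax.numpy as jnp
+ from jax.sharding import Mesh
+
+ from configuration_tpu_gemma3 import TPUGemma3Config
+ from modelling_flax_tpu_gemma3 import FlaxTPUGemma3ForCausalLM
+
+ config = TPUGemma3Config.from_pretrained(".")
+ # The Flax module reads a device mesh with ("data", "model") axes from the config
+ # (see FlaxTPUGemma3LayerCollection.setup); a single-device mesh suffices for a smoke test.
+ config.mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("data", "model"))
+
+ model = FlaxTPUGemma3ForCausalLM.from_pretrained(".", config=config)
+
+ input_ids = jnp.array([[2, 10, 20, 30]], dtype=jnp.int32)  # bos_token_id is 2
+ outputs = model(input_ids)
+ print(outputs.logits.shape)  # (1, 4, 262144)
+ ```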
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "architectures": [
3
+ "TPUGemma3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "attn_logit_softcapping": null,
8
+ "auto_map": {
9
+ "FlaxAutoModelForCausalLM": "modelling_flax_tpu_gemma3.FlaxTPUGemma3ForCausalLM"
10
+ },
11
+ "bos_token_id": 2,
12
+ "cache_implementation": "hybrid",
13
+ "eos_token_id": 1,
14
+ "expand_input_ids": false,
15
+ "expand_input_ids_dict": null,
16
+ "expand_input_ids_maxlen": null,
17
+ "expand_input_ids_vocab_size": null,
18
+ "final_logit_softcapping": null,
19
+ "head_dim": 256,
20
+ "hidden_activation": "gelu_pytorch_tanh",
21
+ "hidden_size": 1152,
22
+ "initializer_range": 0.02,
23
+ "intermediate_size": 6912,
24
+ "max_position_embeddings": 8192,
25
+ "model_type": "tpu_gemma3",
26
+ "num_attention_heads": 4,
27
+ "num_hidden_layers": 26,
28
+ "num_key_value_heads": 1,
29
+ "pad_token_id": 0,
30
+ "previous_hidden_size": null,
31
+ "project_mode": null,
32
+ "query_pre_attn_scalar": 256,
33
+ "rms_norm_eps": 1e-06,
34
+ "rope_local_base_freq": 10000,
35
+ "rope_scaling": null,
36
+ "rope_theta": 1000000,
37
+ "skip_out_norm": false,
38
+ "sliding_window": 512,
39
+ "sliding_window_pattern": 6,
40
+ "torch_dtype": "float32",
41
+ "transformers_version": "4.52.3",
42
+ "use_cache": true,
43
+ "vocab_size": 262144
44
+ }
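The `auto_map` above registers only the Flax auto class; the configuration class itself ships in `configuration_tpu_gemma3.py` below. A hedged sketch of reading this file from a local checkout of the repository:

```python
from configuration_tpu_gemma3 import TPUGemma3Config

cfg = TPUGemma3Config.from_pretrained(".")  # picks up the config.json above

assert cfg.model_type == "tpu_gemma3"
# The query width is decoupled from the residual width in this checkpoint:
assert cfg.num_attention_heads * cfg.head_dim == 1024 and cfg.hidden_size == 1152
# num_key_value_heads == 1 -> multi-query attention (all query heads share one KV head)
print(cfg.sliding_window, cfg.sliding_window_pattern)  # 512 6 -> five local layers per global layer
```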
configuration_tpu_gemma3.py ADDED
@@ -0,0 +1,91 @@
1
+ """TPU Gemma3 model configuration"""
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.modeling_rope_utils import rope_config_validation
5
+
6
+
7
+ class TPUGemma3Config(PretrainedConfig):
8
+ model_type = "tpu_gemma3"
9
+ keys_to_ignore_at_inference = ["past_key_values"]
10
+
11
+ def __init__(
12
+ self,
13
+ vocab_size=262_208,
14
+ hidden_size=2304,
15
+ intermediate_size=9216,
16
+ num_hidden_layers=26,
17
+ num_attention_heads=8,
18
+ num_key_value_heads=4,
19
+ head_dim=256,
20
+ hidden_activation="gelu_pytorch_tanh",
21
+ max_position_embeddings=131_072,
22
+ initializer_range=0.02,
23
+ rms_norm_eps=1e-6,
24
+ use_cache=True,
25
+ pad_token_id=0,
26
+ eos_token_id=1,
27
+ bos_token_id=2,
28
+ tie_word_embeddings=True,
29
+ rope_theta=1_000_000.0,
30
+ attention_bias=False,
31
+ attention_dropout=0.0,
32
+ query_pre_attn_scalar=256,
33
+ sliding_window=4096,
34
+ final_logit_softcapping=None,
35
+ attn_logit_softcapping=None,
36
+ cache_implementation="hybrid",
37
+ rope_scaling=None,
38
+ rope_local_base_freq=10_000.0,
39
+ sliding_window_pattern=6,
40
+ expand_input_ids=False, # Transformers-native PyTorch generation support
41
+ expand_input_ids_maxlen=None,
42
+ expand_input_ids_vocab_size=None,
43
+ expand_input_ids_dict=None,
44
+ project_mode=None, # latent projection args
45
+ previous_hidden_size=None,
46
+ skip_out_norm=False,
47
+ **kwargs,
48
+ ):
49
+ super().__init__(
50
+ pad_token_id=pad_token_id,
51
+ bos_token_id=bos_token_id,
52
+ eos_token_id=eos_token_id,
53
+ tie_word_embeddings=tie_word_embeddings,
54
+ **kwargs,
55
+ )
56
+ self.vocab_size = vocab_size
57
+ self.max_position_embeddings = max_position_embeddings
58
+ self.hidden_size = hidden_size
59
+ self.intermediate_size = intermediate_size
60
+ self.num_hidden_layers = num_hidden_layers
61
+ self.num_attention_heads = num_attention_heads
62
+ self.head_dim = head_dim
63
+ self.num_key_value_heads = num_key_value_heads
64
+ self.initializer_range = initializer_range
65
+ self.rms_norm_eps = rms_norm_eps
66
+ self.use_cache = use_cache
67
+ self.rope_theta = rope_theta
68
+ self.attention_bias = attention_bias
69
+ self.attention_dropout = attention_dropout
70
+ self.hidden_activation = hidden_activation
71
+ self.query_pre_attn_scalar = query_pre_attn_scalar
72
+ self.sliding_window = sliding_window
73
+ self.final_logit_softcapping = final_logit_softcapping
74
+ self.attn_logit_softcapping = attn_logit_softcapping
75
+ self.cache_implementation = cache_implementation
76
+
77
+ self.rope_local_base_freq = rope_local_base_freq
78
+ # For configuring HybridCache to work with 5:1 attention pattern
79
+ self.sliding_window_pattern = sliding_window_pattern
80
+ self.rope_scaling = rope_scaling
81
+ rope_config_validation(self)
82
+
83
+ self.expand_input_ids = expand_input_ids
84
+ self.expand_input_ids_maxlen = expand_input_ids_maxlen
85
+ self.expand_input_ids_vocab_size = expand_input_ids_vocab_size
86
+ self.expand_input_ids_dict = expand_input_ids_dict
87
+
88
+ self.project_mode = project_mode
89
+ self.previous_hidden_size = previous_hidden_size
90
+
91
+ self.skip_out_norm = skip_out_norm
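The `sliding_window_pattern` comment above ("5:1 attention pattern") can be made concrete with a short illustration that mirrors the layer-type rule used by `FlaxTPUGemma3Attention.setup` in the modelling file below (illustrative only, not part of the shipped code):

```python
from configuration_tpu_gemma3 import TPUGemma3Config

cfg = TPUGemma3Config(num_hidden_layers=26, sliding_window_pattern=6)

# Same rule as FlaxTPUGemma3Attention.setup: a layer uses local (sliding-window)
# attention unless its 1-based index is a multiple of sliding_window_pattern.
is_sliding = [bool((i + 1) % cfg.sliding_window_pattern) for i in range(cfg.num_hidden_layers)]
print(sum(is_sliding), "local layers,", len(is_sliding) - sum(is_sliding), "global layers")  # 22 local, 4 global
```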
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15cb9cde3a6179d743540dbefaaa165becc39ccee99adfa840ed2c5fb657c6f3
3
+ size 3999559506
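This is a Git LFS pointer; the actual ~4 GB payload is a single msgpack blob of Flax parameters as written by `save_pretrained`. A hedged sketch for peeking at it without instantiating the model (assumes `git lfs pull` has been run and enough RAM is available):

```python
from flax.serialization import msgpack_restore
from flax.traverse_util import flatten_dict

with open("flax_model.msgpack", "rb") as f:
    params = msgpack_restore(f.read())  # nested dict of numpy arrays

# Print the first few parameter paths with their shapes and dtypes.
for path, leaf in list(flatten_dict(params).items())[:5]:
    print("/".join(path), leaf.shape, leaf.dtype)
```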
modelling_flax_tpu_gemma3.py ADDED
@@ -0,0 +1,952 @@
1
+ """Flax TPU Gemma3 model."""
2
+
3
+ from typing import Optional, Tuple
4
+ import copy
5
+
6
+ import flax.linen as nn
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
11
+ from flax.linen import combine_masks, make_causal_mask
12
+ from flax.linen.attention import dot_product_attention_weights
13
+ from flax.linen import partitioning as nn_partitioning
14
+ from flax.traverse_util import flatten_dict, unflatten_dict
15
+ from jax import lax
16
+ from jax.sharding import PartitionSpec as P
17
+
18
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
19
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
20
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
21
+ from .configuration_tpu_gemma3 import TPUGemma3Config
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ _CONFIG_FOR_DOC = "TPUGemma3Config"
27
+ _CHECKPOINT_FOR_DOC = "google/gemma-2-2b"
28
+ _REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_llama_3b_v2"
29
+
30
+ TPU_GEMMA3_START_DOCSTRING = r"""
31
+
32
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
33
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
34
+ etc.)
35
+
36
+ This model is also a Flax Linen
37
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
38
+ regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
39
+
40
+ Finally, this model supports inherent JAX features such as:
41
+
42
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
43
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
44
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
45
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
46
+
47
+ Parameters:
48
+ config ([`TPUGemma3Config`]): Model configuration class with all the parameters of the model.
49
+ Initializing with a config file does not load the weights associated with the model, only the
50
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
51
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
52
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or
53
+ `jax.numpy.bfloat16`.
54
+
55
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
56
+ specified, all the computation will be performed with the given `dtype`.
57
+
58
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
59
+ parameters.**
60
+
61
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
62
+ [`~FlaxPreTrainedModel.to_bf16`].
63
+ """
64
+
65
+ TPU_GEMMA3_INPUTS_DOCSTRING = r"""
66
+ Args:
67
+ input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
68
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
69
+ it.
70
+
71
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
72
+ [`PreTrainedTokenizer.__call__`] for details.
73
+
74
+ [What are input IDs?](../glossary#input-ids)
75
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
76
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
77
+
78
+ - 1 for tokens that are **not masked**,
79
+ - 0 for tokens that are **masked**.
80
+
81
+ [What are attention masks?](../glossary#attention-mask)
82
+
83
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
84
+ [`PreTrainedTokenizer.__call__`] for details.
85
+
86
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
87
+ `past_key_values`).
88
+
89
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
90
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
91
+ information on the default strategy.
92
+
93
+ - 1 indicates the head is **not masked**,
94
+ - 0 indicates the head is **masked**.
95
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
96
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
97
+ config.n_positions - 1]`.
98
+
99
+ [What are position IDs?](../glossary#position-ids)
100
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
101
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
102
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
103
+ output_attentions (`bool`, *optional*):
104
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
105
+ tensors for more detail.
106
+ output_hidden_states (`bool`, *optional*):
107
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
108
+ more detail.
109
+ return_dict (`bool`, *optional*):
110
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
111
+ """
112
+
113
+ remat = nn_partitioning.remat
114
+
115
+ def create_sinusoidal_positions(num_pos, dim):
116
+ inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2)[: (dim // 2)] / dim))
117
+ freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
118
+
119
+ emb = np.concatenate((freqs, freqs), axis=-1)
120
+ out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
121
+ return jnp.array(out[:, :, :num_pos])
122
+
123
+
124
+ # Copied from transformers.models.llama.modeling_flax_llama.rotate_half
125
+ def rotate_half(tensor):
126
+ """Rotates half the hidden dims of the input."""
127
+ rotate_half_tensor = jnp.concatenate(
128
+ (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
129
+ )
130
+ return rotate_half_tensor
131
+
132
+
133
+ # Copied from transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb
134
+ def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
135
+ return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)
136
+
137
+
138
+ class FlaxTPUGemma3RMSNorm(nn.Module):
139
+ config: TPUGemma3Config
140
+ dim_override: Optional[int] = None
141
+ dtype: jnp.dtype = jnp.float32
142
+ add_in_projection: bool = False
143
+ add_out_projection: bool = False
144
+
145
+ def setup(self):
146
+ self.epsilon = self.config.rms_norm_eps
147
+
148
+ self.weight_is_matrix = False
149
+
150
+ if self.dim_override is not None:
151
+ self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.dim_override)
152
+ else:
153
+ if self.add_in_projection:
154
+ self.in_projection = self.param("in_projection", lambda _, shape: jnp.empty(shape), (self.config.hidden_size, self.config.previous_hidden_size))
155
+ self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.previous_hidden_size)
156
+ elif self.config.project_mode == "wrap":
157
+ self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.previous_hidden_size)
158
+ elif isinstance(self.config.project_mode, str) and self.config.project_mode.startswith("fuse"):
159
+ self.weight = self.param("weight", lambda _, shape: jnp.eye(shape), self.config.hidden_size)
160
+ self.weight_is_matrix = True
161
+ else:
162
+ self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size)
163
+
164
+ if self.add_out_projection:
165
+ self.out_projection = self.param("out_projection", lambda _, shape: jnp.empty(shape), (self.config.previous_hidden_size, self.config.hidden_size))
166
+
167
+ def __call__(self, hidden_states):
168
+ if self.add_in_projection:
169
+ hidden_states = hidden_states @ self.in_projection
170
+
171
+ variance = jnp.asarray(hidden_states, dtype=jnp.float32)
172
+ variance = jnp.power(variance, 2)
173
+ variance = variance.mean(-1, keepdims=True)
174
+ # use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt`
175
+ hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon)
176
+
177
+ if self.weight_is_matrix:
178
+ hidden_states = jnp.asarray(hidden_states, dtype=self.dtype) @ self.weight
179
+ else:
180
+ hidden_states = (1 + self.weight) * jnp.asarray(hidden_states, dtype=self.dtype)
181
+
182
+ if self.add_out_projection:
183
+ hidden_states = hidden_states @ self.out_projection
184
+
185
+ return hidden_states
186
+
187
+
188
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding with Llama->Gemma3
189
+ class FlaxTPUGemma3RotaryEmbedding(nn.Module):
190
+ config: TPUGemma3Config
191
+ dtype: jnp.dtype = jnp.float32
192
+
193
+ # Ignore copy
194
+ def setup(self):
195
+ head_dim = self.config.head_dim
196
+ self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim)
197
+
198
+ def __call__(self, position_ids):
199
+ sincos = self.sincos[position_ids]
200
+ sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1)
201
+
202
+ return sin_pos, cos_pos
203
+
204
+
205
+ class FlaxTPUGemma3Attention(nn.Module):
206
+ config: TPUGemma3Config
207
+ layer_idx: int
208
+ dtype: jnp.dtype = jnp.float32
209
+ causal: bool = True
210
+ is_cross_attention: bool = False
211
+
212
+ def setup(self):
213
+ self.is_sliding = bool((self.layer_idx + 1) % self.config.sliding_window_pattern)
214
+ self.sliding_window = self.config.sliding_window if self.is_sliding else None
215
+
216
+ config = self.config
217
+ if self.config.project_mode == "wrap":
218
+ self.embed_dim = config.previous_hidden_size
219
+ else:
220
+ self.embed_dim = config.hidden_size
221
+
222
+ self.num_heads = config.num_attention_heads
223
+ self.head_dim = config.head_dim
224
+
225
+ # otherwise we would manually have to scale attn weights
226
+ assert config.query_pre_attn_scalar == config.head_dim
227
+
228
+ self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
229
+
230
+ self.num_key_value_heads = config.num_key_value_heads
231
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
232
+
233
+ kernel = jax.nn.initializers.normal(self.config.initializer_range)
234
+ self.q_proj = nn.Dense(
235
+ self.num_heads * self.head_dim, use_bias=config.attention_bias, dtype=self.dtype, kernel_init=kernel
236
+ )
237
+ self.k_proj = nn.Dense(
238
+ self.num_key_value_heads * self.head_dim,
239
+ use_bias=config.attention_bias,
240
+ dtype=self.dtype,
241
+ kernel_init=kernel,
242
+ )
243
+ self.v_proj = nn.Dense(
244
+ self.num_key_value_heads * self.head_dim,
245
+ use_bias=config.attention_bias,
246
+ dtype=self.dtype,
247
+ kernel_init=kernel,
248
+ )
249
+ self.q_norm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, dim_override=self.head_dim)
250
+ self.k_norm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, dim_override=self.head_dim)
251
+
252
+ self.o_proj = nn.Dense(self.embed_dim, use_bias=config.attention_bias, dtype=self.dtype, kernel_init=kernel)
253
+
254
+ self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
255
+
256
+ def _split_heads(self, hidden_states, num_heads):
257
+ return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
258
+
259
+ def _merge_heads(self, hidden_states):
260
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads * self.head_dim,))
261
+
262
+ @nn.compact
263
+ # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache
264
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
265
+ """
266
+ This function takes projected key, value states from a single input token and concatenates the states to cached
267
+ states from previous steps. This function is slightly adapted from the official Flax repository:
268
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
269
+ """
270
+ # detect if we're initializing by absence of existing cache data.
271
+ is_initialized = self.has_variable("cache", "cached_key")
272
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
273
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
274
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
275
+
276
+ if is_initialized:
277
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
278
+ # update key, value caches with our new 1d spatial slices
279
+ cur_index = cache_index.value
280
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
281
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
282
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
283
+ cached_key.value = key
284
+ cached_value.value = value
285
+ num_updated_cache_vectors = query.shape[1]
286
+ cache_index.value = cache_index.value + num_updated_cache_vectors
287
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
288
+ pad_mask = jnp.broadcast_to(
289
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
290
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
291
+ )
292
+ attention_mask = combine_masks(pad_mask, attention_mask)
293
+ return key, value, attention_mask
294
+
295
+ def __call__(
296
+ self,
297
+ hidden_states,
298
+ position_embeddings,
299
+ attention_mask,
300
+ position_ids,
301
+ deterministic: bool = True,
302
+ init_cache: bool = False,
303
+ output_attentions: bool = False,
304
+ ):
305
+ raw_query = self.q_proj(hidden_states)
306
+ raw_key = self.k_proj(hidden_states)
307
+ raw_value = self.v_proj(hidden_states)
308
+
309
+ query = self._split_heads(raw_query, self.num_heads)
310
+ key = self._split_heads(raw_key, self.num_key_value_heads)
311
+ value = self._split_heads(raw_value, self.num_key_value_heads)
312
+
313
+ query = self.q_norm(query)
314
+ key = self.k_norm(key)
315
+
316
+ sin, cos = position_embeddings
317
+
318
+ key = jnp.asarray(apply_rotary_pos_emb(key, sin, cos), dtype=self.dtype)
319
+ query = jnp.asarray(apply_rotary_pos_emb(query, sin, cos), dtype=self.dtype)
320
+
321
+ query_length, key_length = query.shape[1], key.shape[1]
322
+
323
+ if self.has_variable("cache", "cached_key"):
324
+ mask_shift = self.variables["cache"]["cache_index"]
325
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
326
+ causal_mask = lax.dynamic_slice(
327
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
328
+ )
329
+ else:
330
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
331
+
332
+ batch_size = hidden_states.shape[0]
333
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
334
+
335
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
336
+ attention_mask = combine_masks(attention_mask, causal_mask)
337
+
338
+ if self.sliding_window is not None:
339
+ min_dtype = jnp.finfo(hidden_states.dtype).min
340
+ sliding_window_mask = jnp.tril(
341
+ jnp.ones_like(attention_mask, dtype=bool), k=-self.sliding_window
342
+ )
343
+ attention_mask = jnp.where(sliding_window_mask, min_dtype, attention_mask)
344
+ if attention_mask.shape[-1] <= 1: # when decoding
345
+ attention_mask = attention_mask[:, :, :, -self.sliding_window :]
346
+
347
+ dropout_rng = None
348
+ if not deterministic and self.config.attention_dropout > 0.0:
349
+ dropout_rng = self.make_rng("dropout")
350
+
351
+ # During fast autoregressive decoding, we feed one position at a time,
352
+ # and cache the keys and values step by step.
353
+ if self.has_variable("cache", "cached_key") or init_cache:
354
+ key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
355
+
356
+ # transform boolean mask into float mask
357
+ attention_bias = lax.select(
358
+ attention_mask > 0,
359
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
360
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
361
+ )
362
+
363
+ key = jnp.repeat(key, repeats=self.num_key_value_groups, axis=2)
364
+ value = jnp.repeat(value, repeats=self.num_key_value_groups, axis=2)
365
+
366
+ # usual dot product attention
367
+ attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
368
+ attn_weights = dot_product_attention_weights(
369
+ query,
370
+ key,
371
+ bias=attention_bias,
372
+ dropout_rng=dropout_rng,
373
+ dropout_rate=self.config.attention_dropout,
374
+ deterministic=deterministic,
375
+ dtype=attention_dtype,
376
+ )
377
+
378
+ if self.config.attn_logit_softcapping is not None:
379
+ attn_weights = attn_weights / self.config.attn_logit_softcapping
380
+ attn_weights = jnp.tanh(attn_weights)
381
+ attn_weights = attn_weights * self.config.attn_logit_softcapping
382
+
383
+ if self.attention_softmax_in_fp32:
384
+ attn_weights = attn_weights.astype(self.dtype)
385
+
386
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
387
+ attn_output = self._merge_heads(attn_output)
388
+ attn_output = self.o_proj(attn_output)
389
+
390
+ outputs = (attn_output, (raw_query, raw_key, raw_value)) if output_attentions else (attn_output,)
391
+ return outputs
392
+
393
+
394
+ class FlaxTPUGemma3MLP(nn.Module):
395
+ config: TPUGemma3Config
396
+ dtype: jnp.dtype = jnp.float32
397
+
398
+ def setup(self):
399
+ if self.config.project_mode == "wrap":
400
+ embed_dim = self.config.previous_hidden_size
401
+ else:
402
+ embed_dim = self.config.hidden_size
403
+
404
+ inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
405
+
406
+ kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
407
+ if self.config.hidden_activation is None:
408
+ logger.warning_once(
409
+ "Gemma3's activation function should be approximate GeLU and not exact GeLU. "
410
+ "Changing the activation function to `gelu_pytorch_tanh`."
411
+ f" If you want to use the legacy `{getattr(self.config, 'hidden_act', None)}`, "
412
+ f"edit the `model.config` to set `hidden_activation={getattr(self.config, 'hidden_act', None)}` "
413
+ " instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
414
+ )
415
+ hidden_activation = "gelu_pytorch_tanh"
416
+ else:
417
+ hidden_activation = self.config.hidden_activation
418
+ self.act = ACT2FN[hidden_activation]
419
+
420
+ self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
421
+ self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
422
+ self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
423
+
424
+ def __call__(self, hidden_states):
425
+ up_proj_states = self.up_proj(hidden_states)
426
+ gate_states = self.act(self.gate_proj(hidden_states))
427
+
428
+ hidden_states = self.down_proj(up_proj_states * gate_states)
429
+ return hidden_states
430
+
431
+
432
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaDecoderLayer with Llama->Gemma3
433
+ class FlaxTPUGemma3DecoderLayer(nn.Module):
434
+ config: TPUGemma3Config
435
+ layer_idx: int
436
+ dtype: jnp.dtype = jnp.float32
437
+
438
+ def setup(self):
439
+ self.input_layernorm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, add_in_projection=self.config.project_mode == "wrap")
440
+ self.self_attn = FlaxTPUGemma3Attention(self.config, self.layer_idx, dtype=self.dtype)
441
+ self.pre_feedforward_layernorm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, add_in_projection=self.config.project_mode == "wrap")
442
+ self.post_feedforward_layernorm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, add_out_projection=self.config.project_mode == "wrap")
443
+ self.post_attention_layernorm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, add_out_projection=self.config.project_mode == "wrap")
444
+ self.mlp = FlaxTPUGemma3MLP(self.config, dtype=self.dtype)
445
+
446
+ def __call__(
447
+ self,
448
+ hidden_states,
449
+ position_embeddings_global,
450
+ position_embeddings_local,
451
+ attention_mask=None,
452
+ position_ids=None,
453
+ deterministic: bool = True,
454
+ init_cache: bool = False,
455
+ output_attentions: bool = False,
456
+ ):
457
+ mesh = getattr(self.config, "mesh", None)
458
+ if mesh is not None:
459
+ hidden_states = jax.lax.with_sharding_constraint(
460
+ hidden_states, jax.sharding.NamedSharding(mesh, P("data", None, "model"))
461
+ )
462
+ residual = hidden_states
463
+ hidden_states = self.input_layernorm(hidden_states)
464
+
465
+ # apply global RoPE to non-sliding layer only
466
+ if self.self_attn.is_sliding:
467
+ position_embeddings = position_embeddings_local
468
+ else:
469
+ position_embeddings = position_embeddings_global
470
+
471
+ outputs = self.self_attn(
472
+ hidden_states,
473
+ position_embeddings,
474
+ attention_mask=attention_mask,
475
+ position_ids=position_ids,
476
+ deterministic=deterministic,
477
+ init_cache=init_cache,
478
+ output_attentions=output_attentions,
479
+ )
480
+ # residual connection
481
+ attn_output = self.post_attention_layernorm(outputs[0])
482
+ hidden_states = residual + attn_output
483
+
484
+ residual = hidden_states
485
+ hidden_states = self.pre_feedforward_layernorm(hidden_states)
486
+ hidden_states = self.mlp(hidden_states)
487
+ mlp_output = self.post_feedforward_layernorm(hidden_states)
488
+ # residual connection
489
+ hidden_states = residual + mlp_output
490
+
491
+ return (hidden_states, attn_output, mlp_output)
492
+
493
+
494
+ # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Gemma3, GPT_NEO->Gemma3, transformer->model
495
+ class FlaxTPUGemma3PreTrainedModel(FlaxPreTrainedModel):
496
+ """
497
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
498
+ models.
499
+ """
500
+
501
+ config_class = TPUGemma3Config
502
+ base_model_prefix = "model"
503
+ module_class: nn.Module = None
504
+
505
+ def __init__(
506
+ self,
507
+ config: TPUGemma3Config,
508
+ input_shape: Tuple = (1, 1),
509
+ seed: int = 0,
510
+ dtype: jnp.dtype = jnp.float32,
511
+ _do_init: bool = True,
512
+ gradient_checkpointing: bool = False,
513
+ **kwargs,
514
+ ):
515
+ module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
516
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
517
+
518
+ def enable_gradient_checkpointing(self):
519
+ self._module = self.module_class(
520
+ config=self.config,
521
+ dtype=self.dtype,
522
+ gradient_checkpointing=True,
523
+ )
524
+
525
+ @classmethod
526
+ def can_generate(cls) -> bool:
527
+ # disable generation, handled separately
528
+ # this is convenient since GenerationConfig.from_model_config(config) needs a pickleable config
529
+ return False
530
+
531
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
532
+ # init input tensors
533
+ input_ids = jnp.zeros(input_shape, dtype="i4")
534
+ attention_mask = jnp.ones_like(input_ids)
535
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
536
+ params_rng, dropout_rng = jax.random.split(rng)
537
+ rngs = {"params": params_rng, "dropout": dropout_rng}
538
+
539
+ random_params = self.module.init(rngs, input_ids, None, attention_mask, position_ids, return_dict=False)["params"]
540
+
541
+ if params is not None:
542
+ random_params = flatten_dict(unfreeze(random_params))
543
+ params = flatten_dict(unfreeze(params))
544
+ for missing_key in self._missing_keys:
545
+ params[missing_key] = random_params[missing_key]
546
+ self._missing_keys = set()
547
+ return freeze(unflatten_dict(params))
548
+ else:
549
+ return random_params
550
+
551
+ def init_cache(self, batch_size, max_length):
552
+ r"""
553
+ Args:
554
+ batch_size (`int`):
555
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
556
+ max_length (`int`):
557
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
558
+ cache.
559
+ """
560
+ # init input variables to retrieve cache
561
+ input_ids = jnp.ones((batch_size, max_length))
562
+ attention_mask = jnp.ones_like(input_ids)
563
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
564
+
565
+ init_variables = self.module.init(
566
+ jax.random.PRNGKey(0), input_ids, None, attention_mask, position_ids, return_dict=False, init_cache=True
567
+ )
568
+ return unfreeze(init_variables["cache"])
569
+
570
+ @add_start_docstrings_to_model_forward(TPU_GEMMA3_INPUTS_DOCSTRING)
571
+ def __call__(
572
+ self,
573
+ input_ids,
574
+ inputs_embeds=None,
575
+ attention_mask=None,
576
+ position_ids=None,
577
+ params: dict = None,
578
+ past_key_values: dict = None,
579
+ dropout_rng: jax.random.PRNGKey = None,
580
+ train: bool = False,
581
+ output_attentions: Optional[bool] = None,
582
+ output_hidden_states: Optional[bool] = None,
583
+ return_dict: Optional[bool] = None,
584
+ ):
585
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
586
+ output_hidden_states = (
587
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
588
+ )
589
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
590
+
591
+ if input_ids is not None:
592
+ batch_size, sequence_length = input_ids.shape
593
+ else:
594
+ batch_size, sequence_length, _ = inputs_embeds.shape
595
+
596
+ if position_ids is None:
597
+ if past_key_values is not None:
598
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
599
+
600
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
601
+
602
+ if attention_mask is None:
603
+ attention_mask = jnp.ones((batch_size, sequence_length))
604
+
605
+ # Handle any PRNG if needed
606
+ rngs = {}
607
+ if dropout_rng is not None:
608
+ rngs["dropout"] = dropout_rng
609
+
610
+ inputs = {"params": params or self.params}
611
+
612
+ # if past_key_values are passed, the cache is already initialized; the private flag init_cache has to be passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be updated by the FlaxTPUGemma3Attention module
613
+ if past_key_values:
614
+ inputs["cache"] = past_key_values
615
+ mutable = ["cache"]
616
+ else:
617
+ mutable = False
618
+
619
+ outputs = self.module.apply(
620
+ inputs,
621
+ jnp.array(input_ids, dtype="i4") if input_ids is not None else None,
622
+ inputs_embeds if inputs_embeds is not None else None,
623
+ jnp.array(attention_mask, dtype="i4"),
624
+ jnp.array(position_ids, dtype="i4"),
625
+ not train,
626
+ False,
627
+ output_attentions,
628
+ output_hidden_states,
629
+ return_dict,
630
+ rngs=rngs,
631
+ mutable=mutable,
632
+ )
633
+
634
+ # add updated cache to model output
635
+ if past_key_values is not None and return_dict:
636
+ outputs, past_key_values = outputs
637
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
638
+ return outputs
639
+ elif past_key_values is not None and not return_dict:
640
+ outputs, past_key_values = outputs
641
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
642
+
643
+ return outputs
644
+
645
+
646
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection with Llama->Gemma3
647
+ class FlaxTPUGemma3LayerCollection(nn.Module):
648
+ config: TPUGemma3Config
649
+ dtype: jnp.dtype = jnp.float32
650
+ gradient_checkpointing: bool = False
651
+
652
+ def setup(self):
653
+ self.rotary_emb = FlaxTPUGemma3RotaryEmbedding(config=self.config)
654
+
655
+ mesh = getattr(self.config, "mesh", None)
656
+ del self.config.mesh
657
+ local_config = copy.deepcopy(self.config)
658
+ if mesh is not None:
659
+ self.config.mesh = mesh
660
+
661
+ local_config.rope_theta = self.config.rope_local_base_freq
662
+ local_config.rope_scaling = {"rope_type": "default"}
663
+ self.rotary_emb_local = FlaxTPUGemma3RotaryEmbedding(config=local_config)
664
+
665
+ if self.gradient_checkpointing:
666
+ FlaxTPUGemma3DecoderCheckpointLayer = remat(FlaxTPUGemma3DecoderLayer, static_argnums=(3, 4, 5))
667
+ self.blocks = [
668
+ FlaxTPUGemma3DecoderCheckpointLayer(self.config, layer_idx, dtype=self.dtype, name=str(layer_idx))
669
+ for layer_idx in range(self.config.num_hidden_layers)
670
+ ]
671
+ else:
672
+ self.blocks = [
673
+ FlaxTPUGemma3DecoderLayer(self.config, layer_idx, dtype=self.dtype, name=str(layer_idx))
674
+ for layer_idx in range(self.config.num_hidden_layers)
675
+ ]
676
+
677
+ def __call__(
678
+ self,
679
+ hidden_states,
680
+ attention_mask=None,
681
+ position_ids=None,
682
+ deterministic: bool = True,
683
+ init_cache: bool = False,
684
+ output_attentions: bool = False,
685
+ output_hidden_states: bool = False,
686
+ return_dict: bool = False,
687
+ ):
688
+ all_attentions = () if output_attentions else None
689
+ all_hidden_states = [(), ()] if output_hidden_states else None
690
+
691
+ position_embeddings_global = self.rotary_emb(position_ids)
692
+ position_embeddings_local = self.rotary_emb_local(position_ids)
693
+
694
+ if output_hidden_states:
695
+ all_hidden_states[0] += (hidden_states,)
696
+ all_hidden_states[1] += (hidden_states,)
697
+
698
+ for block_idx, block in enumerate(self.blocks):
699
+ layer_outputs = block(
700
+ hidden_states,
701
+ position_embeddings_global,
702
+ position_embeddings_local,
703
+ attention_mask,
704
+ position_ids,
705
+ deterministic,
706
+ init_cache,
707
+ output_attentions,
708
+ )
709
+ hidden_states = layer_outputs[0]
710
+
711
+ if output_hidden_states:
712
+ # last block is followed by norm - added later
713
+ if block_idx != len(self.blocks) - 1:
714
+ all_hidden_states[0] += (hidden_states,)
715
+
716
+ all_hidden_states[1] += layer_outputs[1:]
717
+
718
+ if output_attentions:
719
+ raise NotImplementedError("Attention outputs are not implemented for TPUGemma3 (with projections).")
720
+
721
+ # this contains possible `None` values - `FlaxGemma3Module` will filter them out
722
+ outputs = (hidden_states, all_hidden_states, all_attentions)
723
+
724
+ return outputs
725
+
726
+
727
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModule with Llama->Gemma3
728
+ class FlaxTPUGemma3Module(nn.Module):
729
+ config: TPUGemma3Config
730
+ dtype: jnp.dtype = jnp.float32
731
+ gradient_checkpointing: bool = False
732
+
733
+ def setup(self):
734
+ if self.config.project_mode == "wrap":
735
+ self.hidden_size = self.config.previous_hidden_size
736
+ else:
737
+ self.hidden_size = self.config.hidden_size
738
+
739
+ embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range)
740
+
741
+ self.embed_tokens = nn.Embed(
742
+ self.config.vocab_size,
743
+ self.hidden_size,
744
+ embedding_init=embedding_init,
745
+ dtype=self.dtype,
746
+ )
747
+ self.layers = FlaxTPUGemma3LayerCollection(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
748
+ self.norm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, add_in_projection=self.config.project_mode == "wrap", add_out_projection=False)
749
+
750
+ if self.config.project_mode == "wrap":
751
+ self.embedding_projection = self.param("embedding_projection", lambda _, shape: jnp.empty(shape), (self.config.previous_hidden_size, self.config.hidden_size))
752
+
753
+ def embed(
754
+ self,
755
+ input_ids,
756
+ ):
757
+ inputs_embeds = self.embed_tokens(input_ids.astype("i4"))
758
+
759
+ if self.config.project_mode is not None:
760
+ scaler = self.config.previous_hidden_size ** 0.5
761
+ else:
762
+ scaler = self.config.hidden_size ** 0.5
763
+
764
+ inputs_embeds = inputs_embeds * scaler
765
+
766
+ if self.config.project_mode == "wrap":
767
+ inputs_embeds = inputs_embeds @ self.embedding_projection
768
+
769
+ return inputs_embeds
770
+
771
+ # Ignore copy
772
+ def __call__(
773
+ self,
774
+ input_ids,
775
+ inputs_embeds=None,
776
+ attention_mask=None,
777
+ position_ids=None,
778
+ deterministic=True,
779
+ init_cache: bool = False,
780
+ output_attentions: bool = False,
781
+ output_hidden_states: bool = False,
782
+ return_dict: bool = True,
783
+ ):
784
+ if inputs_embeds is None:
785
+ inputs_embeds = self.embed(input_ids)
786
+
787
+ outputs = self.layers(
788
+ inputs_embeds,
789
+ position_ids=position_ids,
790
+ attention_mask=attention_mask,
791
+ deterministic=deterministic,
792
+ init_cache=init_cache,
793
+ output_attentions=output_attentions,
794
+ output_hidden_states=output_hidden_states,
795
+ return_dict=return_dict,
796
+ )
797
+
798
+ hidden_states = outputs[0]
799
+
800
+ if not self.config.skip_out_norm:
801
+ hidden_states = self.norm(hidden_states)
802
+
803
+ if output_hidden_states:
804
+ all_hidden_states = outputs[1]
805
+
806
+ all_hidden_states[0] += (hidden_states,)
807
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
808
+ else:
809
+ outputs = (hidden_states,) + outputs[1:]
810
+
811
+ if not return_dict:
812
+ return tuple(v for v in outputs if v is not None)
813
+
814
+ return FlaxBaseModelOutput(
815
+ last_hidden_state=hidden_states,
816
+ hidden_states=outputs[1],
817
+ attentions=outputs[-1],
818
+ )
819
+
820
+
821
+ @add_start_docstrings(
822
+ "The bare Gemma3 Model transformer outputting raw hidden-states without any specific head on top.",
823
+ TPU_GEMMA3_START_DOCSTRING,
824
+ )
825
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModel with Llama->Gemma3
826
+ class FlaxTPUGemma3Model(FlaxTPUGemma3PreTrainedModel):
827
+ module_class = FlaxTPUGemma3Module
828
+
829
+
830
+ append_call_sample_docstring(
831
+ FlaxTPUGemma3Model,
832
+ _CHECKPOINT_FOR_DOC,
833
+ FlaxBaseModelOutput,
834
+ _CONFIG_FOR_DOC,
835
+ real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
836
+ )
837
+
838
+
839
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule with Llama->Gemma3
840
+ class FlaxTPUGemma3ForCausalLMModule(nn.Module):
841
+ config: TPUGemma3Config
842
+ dtype: jnp.dtype = jnp.float32
843
+ gradient_checkpointing: bool = False
844
+
845
+ def setup(self):
846
+ self.model = FlaxTPUGemma3Module(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
847
+ self.lm_head = nn.Dense(
848
+ self.config.vocab_size,
849
+ use_bias=False,
850
+ dtype=self.dtype,
851
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
852
+ )
853
+
854
+ def embed(self, input_ids):
855
+ return self.model.embed(input_ids)
856
+
857
+ # Ignore copy
858
+ def __call__(
859
+ self,
860
+ input_ids,
861
+ inputs_embeds=None,
862
+ attention_mask=None,
863
+ position_ids=None,
864
+ deterministic: bool = True,
865
+ init_cache: bool = False,
866
+ output_attentions: bool = False,
867
+ output_hidden_states: bool = False,
868
+ return_dict: bool = True,
869
+ ):
870
+ outputs = self.model(
871
+ input_ids,
872
+ inputs_embeds=inputs_embeds,
873
+ position_ids=position_ids,
874
+ attention_mask=attention_mask,
875
+ deterministic=deterministic,
876
+ init_cache=init_cache,
877
+ output_attentions=output_attentions,
878
+ output_hidden_states=output_hidden_states,
879
+ return_dict=return_dict,
880
+ )
881
+
882
+ hidden_states = outputs[0]
883
+ # should be skipped automatically in this case (since unused), but check if JIT actually does this
884
+ if not self.config.skip_out_norm:
885
+ if self.config.tie_word_embeddings:
886
+ shared_kernel = self.model.variables["params"]["embed_tokens"]["embedding"].T
887
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
888
+ else:
889
+ lm_logits = self.lm_head(hidden_states)
890
+
891
+ lm_logits = jax.lax.with_sharding_constraint(
892
+ lm_logits,
893
+ jax.sharding.NamedSharding(getattr(self.config, "mesh"), P("data", None, "model")),
894
+ )
895
+
896
+ if self.config.final_logit_softcapping is not None:
897
+ lm_logits = lm_logits / self.config.final_logit_softcapping
898
+ lm_logits = jnp.tanh(lm_logits)
899
+ lm_logits = lm_logits * self.config.final_logit_softcapping
900
+ else:
901
+ lm_logits = None
902
+
903
+ if not return_dict:
904
+ return (lm_logits,) + outputs[1:]
905
+
906
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
907
+
908
+
909
+ @add_start_docstrings(
910
+ """
911
+ The Gemma3 Model transformer with a language modeling head (linear layer) on top.
912
+ """,
913
+ TPU_GEMMA3_START_DOCSTRING,
914
+ )
915
+ # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Gemma3
916
+ class FlaxTPUGemma3ForCausalLM(FlaxTPUGemma3PreTrainedModel):
917
+ module_class = FlaxTPUGemma3ForCausalLMModule
918
+
919
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
920
+ # initializing the cache
921
+ batch_size, seq_length = input_ids.shape
922
+
923
+ past_key_values = self.init_cache(batch_size, max_length)
924
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
925
+ # But since Gemma3 uses a causal mask, those positions are masked anyways.
926
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
927
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
928
+ if attention_mask is not None:
929
+ position_ids = attention_mask.cumsum(axis=-1) - 1
930
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
931
+ else:
932
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
933
+
934
+ return {
935
+ "past_key_values": past_key_values,
936
+ "attention_mask": extended_attention_mask,
937
+ "position_ids": position_ids,
938
+ }
939
+
940
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
941
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
942
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
943
+ return model_kwargs
944
+
945
+
946
+ append_call_sample_docstring(
947
+ FlaxTPUGemma3ForCausalLM,
948
+ _CHECKPOINT_FOR_DOC,
949
+ FlaxCausalLMOutput,
950
+ _CONFIG_FOR_DOC,
951
+ real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
952
+ )
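Since `can_generate()` is disabled above ("disable generation, handled separately"), `model.generate` is unavailable. Below is a hedged sketch of a manual greedy loop built on `prepare_inputs_for_generation` / `update_inputs_for_generation`, reusing the `model` and `input_ids` from the README sketch (illustrative only, not the author's decoding code):

```python
import jax.numpy as jnp

max_length = 16
kwargs = model.prepare_inputs_for_generation(input_ids, max_length)  # initializes the KV cache

ids = input_ids
step_input = input_ids  # first step consumes the whole prompt
while ids.shape[1] < max_length:
    outputs = model(step_input, **kwargs)  # kwargs: past_key_values, attention_mask, position_ids
    next_token = jnp.argmax(outputs.logits[:, -1, :], axis=-1)[:, None]
    ids = jnp.concatenate([ids, next_token], axis=1)
    kwargs = model.update_inputs_for_generation(outputs, kwargs)  # advances cache and position_ids
    step_input = next_token  # subsequent steps feed one token at a time
print(ids)
```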