Spaces:
Running
Running
| import jax | |
| import flax.linen as nn | |
| from transformers.models.bart.modeling_flax_bart import ( | |
| FlaxBartModule, | |
| FlaxBartForConditionalGenerationModule, | |
| FlaxBartForConditionalGeneration, | |
| FlaxBartEncoder, | |
| FlaxBartDecoder | |
| ) | |
| from transformers import BartConfig | |
| # Model hyperparameters, for convenience | |
| OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos | |
| OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos | |
| BOS_TOKEN_ID = 16384 | |
| BASE_MODEL = 'facebook/bart-large-cnn' # we currently have issues with bart-large | |
| class CustomFlaxBartModule(FlaxBartModule): | |
| def setup(self): | |
| # check config is valid, otherwise set default values | |
| self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE) | |
| self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH) | |
| # we keep shared to easily load pre-trained weights | |
| self.shared = nn.Embed( | |
| self.config.vocab_size, | |
| self.config.d_model, | |
| embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), | |
| dtype=self.dtype, | |
| ) | |
| # a separate embedding is used for the decoder | |
| self.decoder_embed = nn.Embed( | |
| self.config.vocab_size_output, | |
| self.config.d_model, | |
| embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), | |
| dtype=self.dtype, | |
| ) | |
| self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) | |
| # the decoder has a different config | |
| decoder_config = BartConfig(self.config.to_dict()) | |
| decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder | |
| decoder_config.vocab_size = self.config.vocab_size_output | |
| self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed) | |
| class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule): | |
| def setup(self): | |
| # check config is valid, otherwise set default values | |
| self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE) | |
| self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype) | |
| self.lm_head = nn.Dense( | |
| self.config.vocab_size_output, | |
| use_bias=False, | |
| dtype=self.dtype, | |
| kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), | |
| ) | |
| self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output)) | |
| class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration): | |
| module_class = CustomFlaxBartForConditionalGenerationModule | |