feat: restore weights on CPU
- src/dalle_mini/model/modeling.py +255 -3
- tools/train/train.py +10 -8
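
For context on the commit title: a common way to keep JAX parameters in host memory is to run the initialization under a CPU-backed jit. The snippet below is a minimal illustrative sketch of that general technique, not code from this repository.

    import jax

    def init_params(rng):
        # Stand-in for a Flax module.init call.
        return {"w": jax.random.normal(rng, (1024, 1024))}

    # backend="cpu" runs the computation on the host, so the freshly created
    # (or restored) parameters live in CPU memory rather than on an accelerator.
    init_on_cpu = jax.jit(init_params, backend="cpu")
    params = init_on_cpu(jax.random.PRNGKey(0))
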
src/dalle_mini/model/modeling.py
CHANGED

@@ -15,16 +15,30 @@
 """ DalleBart model. """
 
 import math
+import os
 from functools import partial
-from
+from pickle import UnpicklingError
+from typing import Optional, Tuple, Union
 
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+import msgpack.exceptions
 from flax.core.frozen_dict import unfreeze
 from flax.linen import make_causal_mask
-from flax.
+from flax.serialization import from_bytes
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
 from jax.random import PRNGKey
+from transformers.configuration_utils import PretrainedConfig
+from transformers.file_utils import (
+    FLAX_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    cached_path,
+    hf_bucket_url,
+    is_offline_mode,
+    is_remote_url,
+)
 from transformers.modeling_flax_outputs import (
     FlaxCausalLMOutputWithCrossAttentions,
     FlaxSeq2SeqLMOutput,

@@ -300,7 +314,8 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
     - added num_params property
     - config_class replaced to DalleBartConfig
     - __init__ accepts abstract_init which does uses parameter shape to initialize the model
-    - init weights on CPU
+    - init weights on CPU with `load_on_cpu`
+    - restore weights on CPU with custom `from_pretrained`
     """
 
     config_class = DalleBartConfig

@@ -359,6 +374,243 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
         ).values()
         return sum(list(num_params))
 
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        dtype: jnp.dtype = jnp.float32,
+        *model_args,
+        **kwargs,
+    ):
+        config = kwargs.pop("config", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+        from_pt = kwargs.pop("from_pt", False)
+        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+
+        user_agent = {
+            "file_type": "model",
+            "framework": "flax",
+            "from_auto_class": from_auto_class,
+        }
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
+        # Load config if we don't provide a configuration
+        if not isinstance(config, PretrainedConfig):
+            config_path = (
+                config if config is not None else pretrained_model_name_or_path
+            )
+            config, model_kwargs = cls.config_class.from_pretrained(
+                config_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                _from_auto=from_auto_class,
+                _from_pipeline=from_pipeline,
+                **kwargs,
+            )
+        else:
+            model_kwargs = kwargs
+
+        # Add the dtype to model_kwargs
+        model_kwargs["dtype"] = dtype
+
+        # Load model
+        if pretrained_model_name_or_path is not None:
+            if os.path.isdir(pretrained_model_name_or_path):
+                if from_pt and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                ):
+                    # Load from a PyTorch checkpoint
+                    archive_file = os.path.join(
+                        pretrained_model_name_or_path, WEIGHTS_NAME
+                    )
+                elif os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
+                ):
+                    # Load from a Flax checkpoint
+                    archive_file = os.path.join(
+                        pretrained_model_name_or_path, FLAX_WEIGHTS_NAME
+                    )
+                else:
+                    raise EnvironmentError(
+                        f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
+                        f"{pretrained_model_name_or_path} or `from_pt` set to False"
+                    )
+            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
+                pretrained_model_name_or_path
+            ):
+                archive_file = pretrained_model_name_or_path
+            else:
+                archive_file = hf_bucket_url(
+                    pretrained_model_name_or_path,
+                    filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
+                    revision=revision,
+                )
+
+            # redirect to the cache, if necessary
+            try:
+                resolved_archive_file = cached_path(
+                    archive_file,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    user_agent=user_agent,
+                )
+            except EnvironmentError as err:
+                logger.error(err)
+                msg = (
+                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
+                    f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
+                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
+                )
+                raise EnvironmentError(msg)
+
+            if resolved_archive_file == archive_file:
+                logger.info(f"loading weights file {archive_file}")
+            else:
+                logger.info(
+                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
+                )
+        else:
+            resolved_archive_file = None
+
+        # init random models
+        model = cls(config, *model_args, **model_kwargs)
+
+        with open(resolved_archive_file, "rb") as state_f:
+            try:
+                state = from_bytes(cls, state_f.read())
+            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+                try:
+                    with open(resolved_archive_file) as f:
+                        if f.read().startswith("version"):
+                            raise OSError(
+                                "You seem to have cloned a repository without having git-lfs installed. Please install "
+                                "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                                "you cloned."
+                            )
+                        else:
+                            raise ValueError from e
+                except (UnicodeDecodeError, ValueError):
+                    raise EnvironmentError(
+                        f"Unable to convert {archive_file} to Flax deserializable object. "
+                    )
+
+        # if model is base model only use model_prefix key
+        if (
+            cls.base_model_prefix not in dict(model.params)
+            and cls.base_model_prefix in state
+        ):
+            state = state[cls.base_model_prefix]
+
+        # if model is head model and we are loading weights from base model
+        # we initialize new params dict with base_model_prefix
+        if (
+            cls.base_model_prefix in dict(model.params)
+            and cls.base_model_prefix not in state
+        ):
+            state = {cls.base_model_prefix: state}
+
+        # flatten dicts
+        state = flatten_dict(state)
+
+        random_state = flatten_dict(unfreeze(model.params))
+
+        missing_keys = model.required_params - set(state.keys())
+        unexpected_keys = set(state.keys()) - model.required_params
+
+        # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
+        # matching the weights in the model.
+        mismatched_keys = []
+        for key in state.keys():
+            if key in random_state and state[key].shape != random_state[key].shape:
+                if ignore_mismatched_sizes:
+                    mismatched_keys.append(
+                        (key, state[key].shape, random_state[key].shape)
+                    )
+                    state[key] = random_state[key]
+                else:
+                    raise ValueError(
+                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
+                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
+                        "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
+                        "model."
+                    )
+
+        # add missing keys as random parameters
+        for missing_key in missing_keys:
+            state[missing_key] = random_state[missing_key]
+
+        # remove unexpected keys to not be saved again
+        for unexpected_key in unexpected_keys:
+            del state[unexpected_key]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
+                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
+                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
+                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
+                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
+                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+            )
+        else:
+            logger.info(
+                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
+            )
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"and are newly initialized: {missing_keys}\n"
+                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
+                f"If your task is similar to the task the model of the checkpoint was trained on, "
+                f"you can already use {model.__class__.__name__} for predictions without further training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, shape1, shape2 in mismatched_keys
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
+                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+
+        # set correct parameters
+        model.params = unflatten_dict(state)
+
+        return model
+
 
 class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
     """
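
For reference, a hypothetical usage sketch of the customized loader added above. The class name DalleBart, the import path, and the checkpoint id are assumptions made for illustration; abstract_init and load_on_cpu are the keyword arguments referenced in the docstring and in tools/train/train.py, forwarded by from_pretrained to the model's __init__.

    import jax.numpy as jnp
    from dalle_mini.model import DalleBart  # assumed import path

    model = DalleBart.from_pretrained(
        "dalle-mini/dalle-mini",   # placeholder checkpoint identifier
        dtype=jnp.float32,
        abstract_init=True,        # build parameters from shapes only
        load_on_cpu=True,          # keep the restored weights in host memory
    )
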
tools/train/train.py
CHANGED

@@ -249,6 +249,9 @@ class TrainingArguments:
             "help": "Number of updates steps to accumulate before performing an update pass."
         },
     )
+    gradient_checkpointing: bool = field(
+        default=False, metadata={"help": "Use gradient checkpointing."}
+    )
 
     learning_rate: float = field(
         default=5e-5, metadata={"help": "The initial learning rate."}

@@ -515,10 +518,8 @@ def main():
         load_on_cpu=True,
     )
 
-    #
-
-        model_args.tokenizer_name, use_fast=True
-    )
+    # update model config per training args
+    model.config.gradient_checkpointing = training_args.gradient_checkpointing
 
     # get PartitionSpec for model params (required to be a dict)
     param_spec = set_partitions(model.params)

@@ -526,14 +527,15 @@ def main():
     # convert params to frozen dict
     model._params = freeze(model.params)
 
+    # Load tokenizer
+    tokenizer = DalleBartTokenizer.from_pretrained(
+        model_args.tokenizer_name, use_fast=True
+    )
+
     # Preprocessing the datasets.
     # We need to normalize and tokenize inputs and targets.
-
     dataset.preprocess(tokenizer=tokenizer, config=model.config)
 
-    # no dropout (hardcoded)
-    model.config.dropout = 0.0
-
     # Initialize our training
     dropout_rng = jax.random.PRNGKey(training_args.seed_model)
 
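
For reference, a sketch of how the new gradient_checkpointing field behaves when parsed with HfArgumentParser, the usual pattern for these dataclass arguments (the parser wiring is assumed here, and the dataclass is reduced to the single new field):

    from dataclasses import dataclass, field

    from transformers import HfArgumentParser

    @dataclass
    class TrainingArguments:
        gradient_checkpointing: bool = field(
            default=False, metadata={"help": "Use gradient checkpointing."}
        )

    parser = HfArgumentParser(TrainingArguments)
    (training_args,) = parser.parse_args_into_dataclasses(["--gradient_checkpointing"])
    assert training_args.gradient_checkpointing is True

    # train.py then copies the flag onto the model config:
    # model.config.gradient_checkpointing = training_args.gradient_checkpointing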