Spaces:
Running
Running
Pedro Cuenca
commited on
Commit
·
1023afa
1
Parent(s):
f5dba1e
Override from_pretrained to support wandb artifacts.
Browse files
src/dalle_mini/model/modeling.py
CHANGED
|
@@ -44,6 +44,7 @@ from transformers.models.bart.modeling_flax_bart import (
|
|
| 44 |
FlaxBartPreTrainedModel,
|
| 45 |
)
|
| 46 |
from transformers.utils import logging
|
|
|
|
| 47 |
|
| 48 |
from .configuration import DalleBartConfig
|
| 49 |
|
|
@@ -561,3 +562,18 @@ class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
|
|
| 561 |
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
| 562 |
|
| 563 |
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
FlaxBartPreTrainedModel,
|
| 45 |
)
|
| 46 |
from transformers.utils import logging
|
| 47 |
+
import wandb
|
| 48 |
|
| 49 |
from .configuration import DalleBartConfig
|
| 50 |
|
|
|
|
| 562 |
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
| 563 |
|
| 564 |
return outputs
|
| 565 |
+
|
| 566 |
+
@classmethod
|
| 567 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 568 |
+
"""
|
| 569 |
+
Initializes from a wandb artifact, or delegates loading to the superclass.
|
| 570 |
+
"""
|
| 571 |
+
if ':' in pretrained_model_name_or_path:
|
| 572 |
+
# wandb artifact
|
| 573 |
+
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
| 574 |
+
|
| 575 |
+
# we download everything, including opt_state, so we can resume training if needed
|
| 576 |
+
# see also: #120
|
| 577 |
+
pretrained_model_name_or_path = artifact.download()
|
| 578 |
+
|
| 579 |
+
return super(DalleBart, cls).from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|