Spaces:
Running
Running
fix: update model name
Browse files- tools/train/train.py +1 -1
tools/train/train.py
CHANGED
|
@@ -398,7 +398,7 @@ def main():
|
|
| 398 |
artifact_dir = artifact.download()
|
| 399 |
|
| 400 |
# load model
|
| 401 |
-
model =
|
| 402 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
| 403 |
print(model.params)
|
| 404 |
|
|
|
|
| 398 |
artifact_dir = artifact.download()
|
| 399 |
|
| 400 |
# load model
|
| 401 |
+
model = DalleBart.from_pretrained(artifact_dir)
|
| 402 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
| 403 |
print(model.params)
|
| 404 |
|