Spaces:
Running
Running
feat(train): use compilation cache
Browse files- tools/train/train.py +6 -0
tools/train/train.py
CHANGED
|
@@ -41,6 +41,7 @@ from flax.serialization import from_bytes, to_bytes
|
|
| 41 |
from flax.training import train_state
|
| 42 |
from flax.training.common_utils import onehot
|
| 43 |
from jax.experimental import PartitionSpec, maps
|
|
|
|
| 44 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
| 45 |
from tqdm import tqdm
|
| 46 |
from transformers import HfArgumentParser
|
|
@@ -53,6 +54,11 @@ from dalle_mini.model import (
|
|
| 53 |
set_partitions,
|
| 54 |
)
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
logger = logging.getLogger(__name__)
|
| 57 |
|
| 58 |
|
|
|
|
| 41 |
from flax.training import train_state
|
| 42 |
from flax.training.common_utils import onehot
|
| 43 |
from jax.experimental import PartitionSpec, maps
|
| 44 |
+
from jax.experimental.compilation_cache import compilation_cache as cc
|
| 45 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
| 46 |
from tqdm import tqdm
|
| 47 |
from transformers import HfArgumentParser
|
|
|
|
| 54 |
set_partitions,
|
| 55 |
)
|
| 56 |
|
| 57 |
+
cc.initialize_cache(
|
| 58 |
+
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2**30
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
logger = logging.getLogger(__name__)
|
| 63 |
|
| 64 |
|