Commit
·
131da64
0
Parent(s):
Initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- .gitignore +37 -0
- .gitmodules +15 -0
- Dockerfile +79 -0
- README.md +82 -0
- __builtins__.pyi +7 -0
- configs/config.yaml +451 -0
- configs/config_empty.yaml +8 -0
- configs/experiments/ar.yaml +10 -0
- configs/experiments/elm.yaml +15 -0
- configs/experiments/eval_model.yaml +21 -0
- configs/experiments/eval_text.yaml +26 -0
- configs/experiments/eval_text_only.yaml +30 -0
- configs/experiments/eval_unified.yaml +27 -0
- configs/experiments/fid_cc12m.yaml +22 -0
- configs/experiments/fid_datacomp1b.yaml +22 -0
- configs/experiments/fid_hf.yaml +25 -0
- configs/experiments/jan_cub.yaml +51 -0
- configs/experiments/large_maskdit_exp.yaml +7 -0
- configs/experiments/large_scale_high_res_interleaved_inference.yaml +51 -0
- configs/experiments/large_scale_train.yaml +151 -0
- configs/experiments/large_scale_train_high_res.yaml +39 -0
- configs/experiments/large_scale_train_high_res_inference.yaml +30 -0
- configs/experiments/large_scale_train_high_res_interleaved.yaml +105 -0
- configs/experiments/maskgit.yaml +6 -0
- configs/experiments/master_eval.yaml +49 -0
- configs/experiments/mscoco_fid.yaml +21 -0
- configs/experiments/paired_standalone_fid_eval.yaml +29 -0
- configs/experiments/small_scale_train.yaml +187 -0
- configs/experiments/small_scale_train_caching.yaml +186 -0
- configs/experiments/small_text_only.yaml +28 -0
- configs/experiments/standalone_fid_eval.yaml +18 -0
- configs/experiments/titok.yaml +8 -0
- configs/experiments/titok_sl256.yaml +7 -0
- configs/experiments/txt_only.yaml +21 -0
- configs/experiments/unified.yaml +23 -0
- configs/experiments/vq16.yaml +9 -0
- configs/experiments/vq16_1024.yaml +8 -0
- configs/experiments/vq16_magvit.yaml +9 -0
- configs/experiments/vq16_t2i.yaml +10 -0
- configs/experiments/webdataset.yaml +12 -0
- configs/experiments/zero_shot_eval.yaml +29 -0
- configs/lr_scheduler/constant_warmup.yaml +2 -0
- configs/lr_scheduler/constant_warmup_cosine_decay.yaml +3 -0
- configs/lr_scheduler/cosine_decay_warmup.yaml +7 -0
- configs/lr_scheduler/cosine_with_hard_restarts_schedule_with_warmup.yaml +4 -0
- configs/model/extra_large.yaml +10 -0
- configs/model/large.yaml +14 -0
- configs/model/medium.yaml +12 -0
- configs/model/small-ar.yaml +11 -0
.gitattributes
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.webp filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
outputs/
|
| 3 |
+
ckpts/
|
| 4 |
+
vqgan/vqgan_pretrained/
|
| 5 |
+
vqgan/vqgan_taming_ckpt/
|
| 6 |
+
data/
|
| 7 |
+
models/datasets/.cache/
|
| 8 |
+
*.json
|
| 9 |
+
output/
|
| 10 |
+
tmp*
|
| 11 |
+
multirun/
|
| 12 |
+
.nfs*
|
| 13 |
+
lightning_logs/
|
| 14 |
+
static/
|
| 15 |
+
archive/
|
| 16 |
+
output_profile/
|
| 17 |
+
logs/
|
| 18 |
+
.history/
|
| 19 |
+
.cache/
|
| 20 |
+
output*/
|
| 21 |
+
*.out
|
| 22 |
+
*.parquet
|
| 23 |
+
wandb/
|
| 24 |
+
vqgan/
|
| 25 |
+
*.csv
|
| 26 |
+
.python-version
|
| 27 |
+
ft_cache/
|
| 28 |
+
alias.txt
|
| 29 |
+
env.sh
|
| 30 |
+
generated_image.png
|
| 31 |
+
Untitled-1.ipynb
|
| 32 |
+
*.log
|
| 33 |
+
demo/old
|
| 34 |
+
*.pem
|
| 35 |
+
.sesskey
|
| 36 |
+
icons.py
|
| 37 |
+
generated/
|
.gitmodules
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "third_party/LlamaGen"]
|
| 2 |
+
path = third_party/LlamaGen
|
| 3 |
+
url = https://github.com/alexanderswerdlow/LlamaGen.git
|
| 4 |
+
branch = wip_v1
|
| 5 |
+
[submodule "third_party/Lumina-mGPT"]
|
| 6 |
+
path = third_party/Lumina-mGPT
|
| 7 |
+
url = https://github.com/alexanderswerdlow/Lumina-mGPT.git
|
| 8 |
+
branch = non_causal
|
| 9 |
+
[submodule "third_party/Show-o"]
|
| 10 |
+
path = third_party/Show-o
|
| 11 |
+
url = https://github.com/showlab/Show-o.git
|
| 12 |
+
[submodule "third_party/1d-tokenizer"]
|
| 13 |
+
path = third_party/1d-tokenizer
|
| 14 |
+
url = https://github.com/bytedance/1d-tokenizer.git
|
| 15 |
+
branch = main
|
Dockerfile
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Base image with CUDA 12.6.3 and cuDNN
|
| 2 |
+
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04
|
| 3 |
+
|
| 4 |
+
# Set environment variables
|
| 5 |
+
ARG DEBIAN_FRONTEND=noninteractive
|
| 6 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 7 |
+
SYSTEM=spaces \
|
| 8 |
+
AM_I_IN_A_DOCKER_CONTAINER=Yes \
|
| 9 |
+
PYTHONPATH=/home/appuser/app \
|
| 10 |
+
HF_HOME=/home/appuser/.cache \
|
| 11 |
+
TORCH_HOME=/home/appuser/.cache \
|
| 12 |
+
TMP_DIR=/home/appuser/tmp \
|
| 13 |
+
TRANSFORMERS_CACHE=/home/appuser/.cache/transformers \
|
| 14 |
+
NVIDIA_VISIBLE_DEVICES=all \
|
| 15 |
+
NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
| 16 |
+
|
| 17 |
+
# Install system dependencies and set Python 3.10 as default
|
| 18 |
+
RUN apt-get update && apt-get install --no-install-recommends -y \
|
| 19 |
+
build-essential \
|
| 20 |
+
python3.10 \
|
| 21 |
+
python3.10-distutils \
|
| 22 |
+
python3-pip \
|
| 23 |
+
ffmpeg \
|
| 24 |
+
libsm6 \
|
| 25 |
+
libxext6 \
|
| 26 |
+
libgl1 \
|
| 27 |
+
git \
|
| 28 |
+
openssh-client \
|
| 29 |
+
&& ln -sf /usr/bin/python3.10 /usr/bin/python \
|
| 30 |
+
&& ln -sf /usr/bin/pip3 /usr/bin/pip \
|
| 31 |
+
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
| 32 |
+
|
| 33 |
+
# Install `uv`
|
| 34 |
+
RUN pip install --upgrade pip \
|
| 35 |
+
&& pip install uv
|
| 36 |
+
|
| 37 |
+
# Create a non-root user
|
| 38 |
+
RUN useradd -m -u 1000 appuser
|
| 39 |
+
|
| 40 |
+
# Set working directory
|
| 41 |
+
WORKDIR /home/appuser/app
|
| 42 |
+
|
| 43 |
+
# Copy dependency files and install dependencies
|
| 44 |
+
COPY --chown=appuser pyproject.toml uv.lock README.md ./
|
| 45 |
+
RUN mkdir -p -m 0600 ~/.ssh && ssh-keyscan github.com >> ~/.ssh/known_hosts
|
| 46 |
+
|
| 47 |
+
RUN --mount=type=ssh uv sync --no-group dev
|
| 48 |
+
RUN --mount=type=ssh uv sync --frozen --no-cache \
|
| 49 |
+
&& chown -R appuser:appuser /home/appuser/app/.venv \
|
| 50 |
+
&& rm -rf /root/.cache /home/appuser/.cache
|
| 51 |
+
|
| 52 |
+
# Ensure non-root user has write access to cache and tmp directories
|
| 53 |
+
RUN mkdir -p /home/appuser/.cache/transformers /home/appuser/tmp /home/appuser/.cache \
|
| 54 |
+
&& chown -R appuser:appuser /home/appuser/.cache /home/appuser/tmp/ /home/appuser/app/
|
| 55 |
+
|
| 56 |
+
RUN chmod -R 777 /tmp
|
| 57 |
+
|
| 58 |
+
# Copy application code
|
| 59 |
+
COPY --chown=appuser demo demo
|
| 60 |
+
COPY --chown=appuser unidisc unidisc
|
| 61 |
+
COPY --chown=appuser models models
|
| 62 |
+
COPY --chown=appuser configs configs
|
| 63 |
+
COPY --chown=appuser third_party third_party
|
| 64 |
+
COPY --chown=appuser ckpts ckpts
|
| 65 |
+
COPY --chown=appuser ./__* ./
|
| 66 |
+
COPY --chown=appuser ./*.py ./
|
| 67 |
+
COPY --chown=appuser ./archive/pytorch_model_fsdp.bin ./
|
| 68 |
+
|
| 69 |
+
# Switch to non-root user
|
| 70 |
+
USER appuser
|
| 71 |
+
|
| 72 |
+
# Expose port for Gradio
|
| 73 |
+
EXPOSE 5003
|
| 74 |
+
|
| 75 |
+
# Command to run the application
|
| 76 |
+
CMD ["bash", "demo/demo.sh"]
|
| 77 |
+
|
| 78 |
+
# DOCKER_BUILDKIT=1 docker build --ssh default --network=host -t unidisc .
|
| 79 |
+
# docker run --network=host -it -p 5003:5003 unidisc
|
README.md
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
<br>
|
| 3 |
+
<img src="docs/images/banner.webp" width="1000">
|
| 4 |
+
<h3>Unified Multimodal Discrete Diffusion</h3>
|
| 5 |
+
|
| 6 |
+
[Alexander Swerdlow](https://aswerdlow.com/)<sup>1*</sup>
|
| 7 |
+
[Mihir Prabhudesai](https://mihirp1998.github.io/)<sup>1*</sup>
|
| 8 |
+
[Siddharth Gandhi](hhttps://www.ssgandhi.com/)<sup>1</sup>
|
| 9 |
+
[Deepak Pathak](https://www.cs.cmu.edu/~dpathak/)<sup>1</sup>
|
| 10 |
+
[Katerina Fragkiadaki](https://www.cs.cmu.edu/~katef/)<sup>1</sup>
|
| 11 |
+
<br>
|
| 12 |
+
|
| 13 |
+
<sup>1</sup> Carnegie Mellon University
|
| 14 |
+
|
| 15 |
+
[](https://arxiv.org/pdf/0000.00000) [](https://unidisc.github.io/)
|
| 16 |
+
|
| 17 |
+
<!-- [](https://huggingface.co/spaces/todo) -->
|
| 18 |
+
|
| 19 |
+
</div>
|
| 20 |
+
|
| 21 |
+
## Hugging Face models and annotations
|
| 22 |
+
|
| 23 |
+
The UniDisc checkpoints are available on [Hugging Face](https://huggingface.co/unidisc):
|
| 24 |
+
* [unidisc/todo](https://huggingface.co/unidisc/todo)
|
| 25 |
+
|
| 26 |
+
## Getting Started
|
| 27 |
+
|
| 28 |
+
To install the dependencies, run:
|
| 29 |
+
```bash
|
| 30 |
+
git submodule update --init --recursive
|
| 31 |
+
uv sync --no-group dev
|
| 32 |
+
uv sync
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
For a more detailed installation guide, please refer to [INSTALL.md](docs/INSTALL.md).
|
| 36 |
+
|
| 37 |
+
## Training
|
| 38 |
+
|
| 39 |
+
See [TRAIN.md](docs/TRAIN.md) for details.
|
| 40 |
+
|
| 41 |
+
## Inference
|
| 42 |
+
|
| 43 |
+
<!-- Inference demo for **TODO**.
|
| 44 |
+
```
|
| 45 |
+
TODO
|
| 46 |
+
``` -->
|
| 47 |
+
<!-- <img src="docs/todo.png" width="1000"> -->
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
Interactive demo for **TODO**.
|
| 51 |
+
```
|
| 52 |
+
python demo/server.py
|
| 53 |
+
python demo/client_simple_fasthtml.py
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
## Training
|
| 58 |
+
|
| 59 |
+
See [TRAINING.md](docs/TRAINING.md) for details.
|
| 60 |
+
|
| 61 |
+
## Evaluation
|
| 62 |
+
|
| 63 |
+
See [EVAL.md](docs/EVAL.md) for details.
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
### Citation
|
| 67 |
+
To cite our work, please use the following:
|
| 68 |
+
```
|
| 69 |
+
@article{TODO,
|
| 70 |
+
title={TODO},
|
| 71 |
+
author={TODO},
|
| 72 |
+
journal={arXiv preprint arXiv:TODO},
|
| 73 |
+
year={TODO}
|
| 74 |
+
}
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## Credits
|
| 78 |
+
|
| 79 |
+
This repository is built on top of the following repositories:
|
| 80 |
+
|
| 81 |
+
- [MDLM](https://github.com/kuleshov-group/mdlm)
|
| 82 |
+
- [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X)
|
__builtins__.pyi
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ipdb import set_trace as st
|
| 2 |
+
from decoupled_utils import start_timing as start_timing
|
| 3 |
+
from decoupled_utils import end_timing as end_timing
|
| 4 |
+
ENABLE_TIMING: bool
|
| 5 |
+
ENABLE_TIMING_SYNC: bool
|
| 6 |
+
DEVICE_BACKEND_TYPE: str
|
| 7 |
+
exists = lambda v: v is not None
|
configs/config.yaml
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- /model: small
|
| 4 |
+
- /noise: loglinear
|
| 5 |
+
- /lr_scheduler: constant_warmup
|
| 6 |
+
- /experiments: []
|
| 7 |
+
# - override hydra/launcher: submitit_slurm
|
| 8 |
+
|
| 9 |
+
slurm: False
|
| 10 |
+
debug: False
|
| 11 |
+
mode: train # train / eval
|
| 12 |
+
diffusion: absorbing_state
|
| 13 |
+
backbone: dit # dit / dimamba / ar
|
| 14 |
+
parameterization: subs # subs / d3pm / sedd
|
| 15 |
+
time_conditioning: False
|
| 16 |
+
T: 0 # 0 (continuous time) / 1000
|
| 17 |
+
subs_masking: False
|
| 18 |
+
seed: 42
|
| 19 |
+
profile: False
|
| 20 |
+
# These belong in trainer.* and hydra.launcher.* but are put here for CLI convinience
|
| 21 |
+
devices: ${device_count:}
|
| 22 |
+
nodes: 1
|
| 23 |
+
partition: ${find_partition:}
|
| 24 |
+
constraint: ${find_constraint:}
|
| 25 |
+
ckpt: null
|
| 26 |
+
|
| 27 |
+
loader:
|
| 28 |
+
desired_global_batch_size: 512
|
| 29 |
+
global_batch_size: null
|
| 30 |
+
eval_global_batch_size: ${.global_batch_size}
|
| 31 |
+
batch_size: ${div_up:${.desired_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
|
| 32 |
+
eval_batch_size: ${div_up:${.desired_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
|
| 33 |
+
num_workers: ${eval:"max(len(__import__('os').sched_getaffinity(0)) // 16, 4)"}
|
| 34 |
+
pin_memory: True
|
| 35 |
+
persistent_workers: True
|
| 36 |
+
|
| 37 |
+
sampling:
|
| 38 |
+
predictor: ddpm_cache # analytic, ddpm, ddpm_cache
|
| 39 |
+
steps: 1000
|
| 40 |
+
max_sampling_steps: 500 # The highest level we use for sampling
|
| 41 |
+
noise_removal: True
|
| 42 |
+
num_sample_log: 2
|
| 43 |
+
semi_ar: False
|
| 44 |
+
stride_length: 1
|
| 45 |
+
num_strides: 1
|
| 46 |
+
|
| 47 |
+
eval:
|
| 48 |
+
checkpoint_path: '' # Used to evaluate a checkpoint after training.
|
| 49 |
+
disable_ema: False
|
| 50 |
+
compute_generative_perplexity: False
|
| 51 |
+
perplexity_batch_size: 8
|
| 52 |
+
gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
|
| 53 |
+
generate_samples: True
|
| 54 |
+
cfg: null
|
| 55 |
+
num_masking_viz_batches: 1
|
| 56 |
+
num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
|
| 57 |
+
test_eval_speed: False
|
| 58 |
+
standalone_fid: False
|
| 59 |
+
visualize_data_only: false
|
| 60 |
+
val_with_train_data: false
|
| 61 |
+
max_num_fid_batches_per_device: null
|
| 62 |
+
class_conditional_fid: false
|
| 63 |
+
compute_entropy: false
|
| 64 |
+
compute_standalone_mauve: false
|
| 65 |
+
compute_standalone_entropy: false
|
| 66 |
+
compute_img_to_txt_mauve_clip: false
|
| 67 |
+
compute_img_to_txt_mauve_during_unconditional_fid: false
|
| 68 |
+
mauve_num_samples: 5000
|
| 69 |
+
mauve_divergence_curve_discretization_size: 25 # default in mauve repo
|
| 70 |
+
mauve_average_over_seeds: 3
|
| 71 |
+
mauve_scaling_factor: 5 # default in mauve repo
|
| 72 |
+
txt_conditional_fid: false
|
| 73 |
+
unconditional_fid: false
|
| 74 |
+
fid_mode: inline
|
| 75 |
+
calculate_clip_score: false
|
| 76 |
+
clean_fid_use_precomputed_stats: false
|
| 77 |
+
clean_fid_precomputed_name: null
|
| 78 |
+
clean_fid_precomputed_split: null
|
| 79 |
+
clean_fid_precomputed_res: null
|
| 80 |
+
attention_caching: false
|
| 81 |
+
set_random_gen_seed: false
|
| 82 |
+
compute_val_metrics_standalone: false
|
| 83 |
+
num_val_metrics_standalone_batches_per_device: ${eval:'max(${eval.num_val_metrics_standalone_samples} // (${trainer.devices} * ${loader.eval_batch_size}), 1)'}
|
| 84 |
+
num_val_metrics_standalone_samples: -1
|
| 85 |
+
return_unweighed_sim: false
|
| 86 |
+
compute_chameleon_perplexity: false
|
| 87 |
+
global_disable_mauve: false
|
| 88 |
+
bypass_normal_validation: false
|
| 89 |
+
auto_enhance: false
|
| 90 |
+
num_auto_enhance_iter: 2
|
| 91 |
+
ar_inpainting_min_val: 0.5
|
| 92 |
+
ar_inpainting_max_val: 1.0
|
| 93 |
+
ar_inpainting_force_val: null
|
| 94 |
+
|
| 95 |
+
optim:
|
| 96 |
+
weight_decay: 0
|
| 97 |
+
lr: 3e-4
|
| 98 |
+
beta1: 0.9
|
| 99 |
+
beta2: 0.999
|
| 100 |
+
eps: 1e-8
|
| 101 |
+
fused: true
|
| 102 |
+
|
| 103 |
+
model:
|
| 104 |
+
use_custom_vae_config: false
|
| 105 |
+
use_custom_vae_ckpt: null
|
| 106 |
+
downscale_ratio: null
|
| 107 |
+
image_vocab_size: null
|
| 108 |
+
vae_type: null
|
| 109 |
+
use_attention_mask: false
|
| 110 |
+
|
| 111 |
+
cond_use_custom_vae_config: false
|
| 112 |
+
cond_use_custom_vae_ckpt: null
|
| 113 |
+
cond_downscale_ratio: null
|
| 114 |
+
cond_image_vocab_size: null
|
| 115 |
+
cond_vae_type: null
|
| 116 |
+
text_model: true
|
| 117 |
+
|
| 118 |
+
attn_type: flash
|
| 119 |
+
force_varlen_attn: false
|
| 120 |
+
force_cast_bf16: false
|
| 121 |
+
norm_type: layernorm
|
| 122 |
+
mup: false
|
| 123 |
+
qk_norm: false
|
| 124 |
+
distillation: false
|
| 125 |
+
force_argmax_valid_indices: false
|
| 126 |
+
use_flash_attn_3: false
|
| 127 |
+
use_spda_attn: false # Spelled wrong...
|
| 128 |
+
rope_2d: false
|
| 129 |
+
modality_embed: false
|
| 130 |
+
zero_linear_init: true
|
| 131 |
+
full_attention: true
|
| 132 |
+
use_lora: false
|
| 133 |
+
use_kv_cache: false
|
| 134 |
+
force_optimized_native_attn: false
|
| 135 |
+
use_pretrained_img_emb: true
|
| 136 |
+
use_flex_attention: false
|
| 137 |
+
add_labels: null
|
| 138 |
+
flex_attention_txt_masking_prob: null
|
| 139 |
+
flex_attention_img_masking_prob: null
|
| 140 |
+
|
| 141 |
+
trainer:
|
| 142 |
+
_target_: lightning.Trainer
|
| 143 |
+
accelerator: cuda
|
| 144 |
+
num_nodes: ${nodes}
|
| 145 |
+
devices: ${devices}
|
| 146 |
+
|
| 147 |
+
# Given a desired global batch size (e.g., how many batches we see before a optim.step, summed over all nodes/gpus/accum_steps), we find the number of gradient accumulations that gets us closest given our current configuration. We assume that loader.batch_size is the largest that can fit in a single fwd/bwd.
|
| 148 |
+
accumulate_grad_batches: ${find_grad_accum:${loader.desired_global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
|
| 149 |
+
gradient_clip_val: 1.0
|
| 150 |
+
precision: 'bf16'
|
| 151 |
+
max_steps: 1_000_000_000
|
| 152 |
+
|
| 153 |
+
num_epochs: 1_000_000_000
|
| 154 |
+
optimizer_cls: adamw
|
| 155 |
+
set_grads_to_none: true
|
| 156 |
+
eval_on_start: true
|
| 157 |
+
eval_decay_steps: false
|
| 158 |
+
eval_epochs: null
|
| 159 |
+
ckpt_steps: 100000
|
| 160 |
+
fsdp: false
|
| 161 |
+
force_enable_checkpointing: false
|
| 162 |
+
limit_val_batches: null
|
| 163 |
+
ckpt_every_n_minutes: 60
|
| 164 |
+
ckpt_recent_timeout_minutes: 10
|
| 165 |
+
checkpoint_all_ranks: true
|
| 166 |
+
force_null_sigma: false
|
| 167 |
+
|
| 168 |
+
log_every_n_steps: 10
|
| 169 |
+
limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
|
| 170 |
+
val_check_interval: 100
|
| 171 |
+
|
| 172 |
+
ema: 0.9999
|
| 173 |
+
antithetic_sampling: True
|
| 174 |
+
importance_sampling: False
|
| 175 |
+
sampling_eps: 1e-3
|
| 176 |
+
change_of_variables: False
|
| 177 |
+
benchmark: true
|
| 178 |
+
backward_pass: true
|
| 179 |
+
forward_pass: true
|
| 180 |
+
profile_memory: false
|
| 181 |
+
pytorch_profile: false
|
| 182 |
+
nvtx_profile: false
|
| 183 |
+
custom_ddp_bf16: true
|
| 184 |
+
log_seperate_modal_losses: true
|
| 185 |
+
use_gradient_checkpointing: false
|
| 186 |
+
text_loss_weight: null
|
| 187 |
+
img_loss_weight: null
|
| 188 |
+
disable_strict_load: false
|
| 189 |
+
attach_oom_observer_eval: false
|
| 190 |
+
find_unused_parameters: false
|
| 191 |
+
restart_on_failure: false
|
| 192 |
+
skip_early_checkpointing: true
|
| 193 |
+
log_flops: true
|
| 194 |
+
sync_timing: false
|
| 195 |
+
use_custom_ema: false
|
| 196 |
+
scale_lr_by_batch_size: false
|
| 197 |
+
tpu_eager: false
|
| 198 |
+
allow_dynamic_nodes: false
|
| 199 |
+
force_disable_signal_handler: false
|
| 200 |
+
tpu_profile: false
|
| 201 |
+
tpu_cache: false
|
| 202 |
+
enable_jax_smi: false
|
| 203 |
+
tpu_compile_debug: false
|
| 204 |
+
xla_spmd: false
|
| 205 |
+
log_grad_norm: true
|
| 206 |
+
tpu_profile_markers: true
|
| 207 |
+
compile: false
|
| 208 |
+
disable_all_checkpointing: false
|
| 209 |
+
tpu_force_mark_step: false
|
| 210 |
+
ar_shift: false
|
| 211 |
+
ar_llm_loss: false
|
| 212 |
+
ar_print_loss: false
|
| 213 |
+
chameleon_z_loss: null
|
| 214 |
+
image_mode: discrete # continuous / discrete
|
| 215 |
+
chameleon_use_ce_loss: false
|
| 216 |
+
low_precision_loss: false
|
| 217 |
+
low_precision_params: false
|
| 218 |
+
scratch: false
|
| 219 |
+
use_spmd_distributed_checkpointing: null
|
| 220 |
+
use_simple_spmd_distributed_checkpointing: false
|
| 221 |
+
load_from_state_dict: null
|
| 222 |
+
load_from_optimizer_state_dict: null
|
| 223 |
+
multimodal_batches: false
|
| 224 |
+
sync_dataloader_timing: false
|
| 225 |
+
compile_flag_pos_emb: false
|
| 226 |
+
compile_fullgraph: false
|
| 227 |
+
compile_mode: max-autotune-no-cudagraphs
|
| 228 |
+
joint_ar_nar_prob: null
|
| 229 |
+
joint_ar_nar_prob_warmup_steps: null
|
| 230 |
+
joint_ar_nar_timestep_warmup_steps: null
|
| 231 |
+
spmd_mesh: null
|
| 232 |
+
detect_anomaly: false
|
| 233 |
+
freeze_chameleon_embeddings: false
|
| 234 |
+
ckpt_model_only: false
|
| 235 |
+
use_orig_params: null
|
| 236 |
+
disable_adjust_num_warmup_steps: false
|
| 237 |
+
mask_entire_modality: null
|
| 238 |
+
iterate_dataloader_only: false
|
| 239 |
+
force_bf16_eval: false
|
| 240 |
+
disable_all_eval_generation: false
|
| 241 |
+
debug_xla_sept: false
|
| 242 |
+
ignore_text_in_unified: false
|
| 243 |
+
allow_null_sigma: false
|
| 244 |
+
disable_forward_autocast_during_eval: false
|
| 245 |
+
viz_images_only: false
|
| 246 |
+
add_label: false
|
| 247 |
+
first_token_dropout: null
|
| 248 |
+
disable_ddp_optimizer: false
|
| 249 |
+
rand_flip_ar_prob: null
|
| 250 |
+
rand_ar_modality_dropout: null
|
| 251 |
+
use_linear_warmup_cosine_annealing: false
|
| 252 |
+
no_ce_weighting: false
|
| 253 |
+
interleaved: false
|
| 254 |
+
interleaved_training_flex_attention: false
|
| 255 |
+
awr: false
|
| 256 |
+
ar_inpainting: false
|
| 257 |
+
|
| 258 |
+
wandb:
|
| 259 |
+
entity: grads
|
| 260 |
+
project: ${eval:'"unidisc-debug" if ${debug} else "unidisc"'}
|
| 261 |
+
resume: ${eval:'"allow" if ${slurm} else None'}
|
| 262 |
+
id: null
|
| 263 |
+
group: null
|
| 264 |
+
job_type: null
|
| 265 |
+
name: null
|
| 266 |
+
tags:
|
| 267 |
+
- ${data.train}
|
| 268 |
+
|
| 269 |
+
checkpointing_root_dir: ${oc.env:UNIDISC_CHECKPOINTING_ROOT_DIR,null}
|
| 270 |
+
root_output_dir: ${oc.env:UNIDISC_ROOT_OUTPUT_DIR,outputs}
|
| 271 |
+
python_orig: |
|
| 272 |
+
accelerate launch \
|
| 273 |
+
--num_machines $SLURM_NNODES \
|
| 274 |
+
--num_processes $NUM_PROCESSES \
|
| 275 |
+
--rdzv_backend c10d \
|
| 276 |
+
--main_process_ip $MASTER_ADDR \
|
| 277 |
+
--main_process_port $MASTER_PORT \
|
| 278 |
+
--machine_rank $SLURM_PROCID \
|
| 279 |
+
--mixed_precision bf16 \
|
| 280 |
+
--dynamo_backend no \
|
| 281 |
+
--enable_cpu_affinity \
|
| 282 |
+
--max_restarts 0 \
|
| 283 |
+
|
| 284 |
+
mem_per_gpu: 40
|
| 285 |
+
cpus_per_gpu: 8
|
| 286 |
+
slurm_name: null
|
| 287 |
+
timeout_min: ${partition_limit:${partition}}
|
| 288 |
+
hydra:
|
| 289 |
+
run:
|
| 290 |
+
dir: ${oc.env:HYDRA_RUN_DIR,${root_output_dir}/outputs/${get_dir_name:}/${oc.env:HYDRA_RUN_DIR_NAME,${now:%Y_%m_%d}/${now:%H_%M_%S}}}
|
| 291 |
+
sweep:
|
| 292 |
+
dir: ${oc.env:HYDRA_RUN_DIR,${root_output_dir}/outputs/${get_dir_name:}/${oc.env:HYDRA_RUN_DIR_NAME,${now:%Y_%m_%d}/${now:%H_%M_%S}}}
|
| 293 |
+
subdir: ${hydra.job.id}
|
| 294 |
+
job:
|
| 295 |
+
chdir: true
|
| 296 |
+
# launcher:
|
| 297 |
+
# name: ${get_slurm_name:}
|
| 298 |
+
# # See https://hydra.cc/docs/configure_hydra/workdir/
|
| 299 |
+
# submitit_folder: ${hydra.sweep.dir}/%j
|
| 300 |
+
# nodes: ${nodes} # Number of nodes. This value is *per* node
|
| 301 |
+
# mem_gb: ${eval:'${mem_per_gpu} * ${trainer.devices}'} # 40GB per gpu. This value is *per* node
|
| 302 |
+
# gpus_per_node: ${trainer.devices}
|
| 303 |
+
# partition: ${partition}
|
| 304 |
+
# constraint: ${constraint}
|
| 305 |
+
# exclude: ${exclude_nodes:}
|
| 306 |
+
|
| 307 |
+
# timeout_min: ${timeout_min}
|
| 308 |
+
# max_num_timeout: 12 # Num requeue exlcuding pre-emptions
|
| 309 |
+
# comment: aswerdlo
|
| 310 |
+
# stderr_to_stdout: true
|
| 311 |
+
|
| 312 |
+
# # Be careful with changing anything below.
|
| 313 |
+
# # see: https://github.com/stas00/ml-engineering/tree/master/training/fault-tolerance#approach-b2-choosing-which-process-to-send-the-signal-to
|
| 314 |
+
# # see: https://github.com/huggingface/accelerate/issues/1918
|
| 315 |
+
|
| 316 |
+
# # The accelerate launcher w/1 initial process and then spawn 1 per GPU
|
| 317 |
+
# tasks_per_node: 1
|
| 318 |
+
# cpus_per_task: ${eval:'${cpus_per_gpu} * ${trainer.devices}'}
|
| 319 |
+
# python: |
|
| 320 |
+
# bash -c "torchrun --nnodes $SLURM_NNODES --nproc_per_node $SLURM_GPUS_PER_NODE --role \$(hostname -s|tr -dc '0-9'): --node_rank \$SLURM_PROCID --max-restarts=2 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
|
| 321 |
+
|
| 322 |
+
# # python: "${getpythoncmd:}"
|
| 323 |
+
# # tasks_per_node: ${devices}
|
| 324 |
+
# # cpus_per_task: 8
|
| 325 |
+
# # python: 'python'
|
| 326 |
+
|
| 327 |
+
# python_suffix: ' --dummy-arg $SLURM_JOB_ID" &'
|
| 328 |
+
# signal: 'B:USR2@360'
|
| 329 |
+
# post_srun_commands:
|
| 330 |
+
# - ''
|
| 331 |
+
# - wait
|
| 332 |
+
|
| 333 |
+
# srun_args:
|
| 334 |
+
# - '--jobid $SLURM_JOB_ID'
|
| 335 |
+
|
| 336 |
+
# setup:
|
| 337 |
+
# - |
|
| 338 |
+
# export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
|
| 339 |
+
# export MASTER_PORT=$(( ($SLURM_JOB_ID % 20001) + 30000 ))
|
| 340 |
+
# export NUM_PROCESSES=$((SLURM_NNODES * SLURM_GPUS_PER_NODE))
|
| 341 |
+
# export NCCL_DEBUG=INFO
|
| 342 |
+
# export NCCL_NSOCKS_PERTHREAD=4
|
| 343 |
+
# export NCCL_SOCKET_NTHREADS=2
|
| 344 |
+
# export OMP_NUM_THREADS=2
|
| 345 |
+
# export PYTHONUNBUFFERED=1
|
| 346 |
+
# export STDOUT_PATH=$(scontrol show job $SLURM_JOB_ID | grep -oP "StdOut=\K[^ ]+")
|
| 347 |
+
# export LOCAL_JOB_FOLDER=$(dirname $STDOUT_PATH)
|
| 348 |
+
# export NCCL_TOPO_DUMP_FILE="$LOCAL_JOB_FOLDER/nccl_topo.xml"
|
| 349 |
+
# if [ -n "$SLURM_RESTART_COUNT" ]; then
|
| 350 |
+
# export RESTART_COUNT=$SLURM_RESTART_COUNT
|
| 351 |
+
# else
|
| 352 |
+
# export RESTART_COUNT=0
|
| 353 |
+
# fi
|
| 354 |
+
# export MAIN_LOG_PATH="$LOCAL_JOB_FOLDER/log_$RESTART_COUNT.txt"
|
| 355 |
+
|
| 356 |
+
# mkdir -p $LOCAL_JOB_FOLDER
|
| 357 |
+
# printenv > "$LOCAL_JOB_FOLDER"/env_"$SLURM_LOCALID_$RESTART_COUNT.txt"
|
| 358 |
+
|
| 359 |
+
# echo "ibstatus: $(ibstatus)"
|
| 360 |
+
# echo "ibdev2netdev: $(ibdev2netdev)"
|
| 361 |
+
# echo "rdma device: $(rdma link)"
|
| 362 |
+
# echo "environment: $(env | grep NCCL)"
|
| 363 |
+
# echo "NUM_PROCESSES: $NUM_PROCESSES, SLURM_NNODES: $SLURM_NNODES SLURM_GPUS_PER_NODE: $SLURM_GPUS_PER_NODE"
|
| 364 |
+
# echo "NODE_ID: $SLURM_NODEID, SLURM_PROCID: $SLURM_PROCID, MASTER_ADDR: $MASTER_ADDR, MASTER_PORT: $MASTER_PORT"
|
| 365 |
+
# echo "PWD: $PWD, LOCAL_JOB_FOLDER: $LOCAL_JOB_FOLDER, MAIN_LOG_PATH: $MAIN_LOG_PATH"
|
| 366 |
+
|
| 367 |
+
# trap 'echo "SIGUSR2 received for $SLURM_JOB_ID"; \
|
| 368 |
+
# if [ -n "$SLURM_ARRAY_JOB_ID" ]; then echo "SLURM_ARRAY_JOB_ID: $SLURM_ARRAY_JOB_ID"; fi; \
|
| 369 |
+
# if [ -n "$SLURM_ARRAY_TASK_ID" ]; then echo "SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"; fi; \
|
| 370 |
+
# # ps auxww | grep $USER; \
|
| 371 |
+
# pid=$(pgrep -u $USER -f "python.*(accelerate|torchrun|deepspeed|distributed\.run).*dummy-arg $SLURM_JOB_ID"); \
|
| 372 |
+
# echo "Found parent PIDs: $pid"; \
|
| 373 |
+
# for p in $pid; do \
|
| 374 |
+
# echo "Parent PID has cmd: $(ps -p $p -o cmd=)"; \
|
| 375 |
+
# children=$(pgrep -P $p); \
|
| 376 |
+
# echo "Children: $children"; \
|
| 377 |
+
# if [ -n "$children" ]; then \
|
| 378 |
+
# for child in $children; do \
|
| 379 |
+
# ppid=$(ps -o ppid= -p $child | tr -d " ")
|
| 380 |
+
# if [ "$ppid" -eq "$p" ]; then
|
| 381 |
+
# echo "Killing direct child process: PID $child with cmd: $(ps -p $child -o cmd=)"
|
| 382 |
+
# kill -USR2 $child &
|
| 383 |
+
# else
|
| 384 |
+
# echo "Skipping non-direct child process: PID $child with PPID $ppid"
|
| 385 |
+
# fi
|
| 386 |
+
# done; \
|
| 387 |
+
# echo "Sent kill signals to children of $p"; \
|
| 388 |
+
# else \
|
| 389 |
+
# echo "No children found for $p"; \
|
| 390 |
+
# fi; \
|
| 391 |
+
# done; \
|
| 392 |
+
# wait;' SIGUSR2
|
| 393 |
+
|
| 394 |
+
checkpointing:
|
| 395 |
+
# Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
|
| 396 |
+
save_dir: ${cwd:}/checkpoints
|
| 397 |
+
# Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
|
| 398 |
+
resume_from_ckpt: true
|
| 399 |
+
resume_ckpt_path: ${cwd:}/checkpoints
|
| 400 |
+
initial_resume_ckpt_path: null
|
| 401 |
+
resume_wandb: true
|
| 402 |
+
checkpoints_total_limit: 2
|
| 403 |
+
use_automatic_naming: false
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
data:
|
| 407 |
+
cache_dir: ${oc.env:HF_DATASETS_CACHE,/grogu/user/mprabhud/aswerdlo/huggingface/datasets}
|
| 408 |
+
num_proc: ${eval:"max(len(__import__('os').sched_getaffinity(0)) // 4, 16)"}
|
| 409 |
+
cond_resolution: null
|
| 410 |
+
iterable: false
|
| 411 |
+
force_disable_shuffle: false
|
| 412 |
+
pin_dataset_to_gpu: false
|
| 413 |
+
webdataset_iterable: false
|
| 414 |
+
webdataset_train_data: null
|
| 415 |
+
webdataset_val_data: null
|
| 416 |
+
webdataset_train_num_samples: null
|
| 417 |
+
webdataset_val_num_samples: null
|
| 418 |
+
webdataset_indexed: false
|
| 419 |
+
dataset_type: null
|
| 420 |
+
keep_tensordict_on_disk: false
|
| 421 |
+
use_token_dataset: false
|
| 422 |
+
use_custom_tensordict_collate: false
|
| 423 |
+
use_weighted_tensordict_sampler: false
|
| 424 |
+
enable_cuda_in_tensordict_collate: true
|
| 425 |
+
data_dir_train: null
|
| 426 |
+
data_dir_val: null
|
| 427 |
+
token_output_dir: null
|
| 428 |
+
wrap_dataloaders: true
|
| 429 |
+
force_shuffle_train: false
|
| 430 |
+
move_tensordict_to_shm: false
|
| 431 |
+
keep_hf_dataset_in_memory: false
|
| 432 |
+
use_chameleon: false
|
| 433 |
+
tokenize_vqvae_in_dataloader: false
|
| 434 |
+
force_mp_spawn: false
|
| 435 |
+
force_raw_images_in_multiple_tensordict: false
|
| 436 |
+
disable_text_modality: false
|
| 437 |
+
txt_only: false
|
| 438 |
+
disable_mask_after_eos: false
|
| 439 |
+
allow_label: false
|
| 440 |
+
split_dataset: false
|
| 441 |
+
img_token_shift: ${model.text_vocab_size}
|
| 442 |
+
zero_shot_eval_dataset: null
|
| 443 |
+
require_sample_ids: false
|
| 444 |
+
use_packing_collate: false
|
| 445 |
+
dynamic_packing_lengths: false
|
| 446 |
+
remove_txt_img_padding: false
|
| 447 |
+
add_image_gen_tokens: false
|
| 448 |
+
use_slow_tokenizer: false
|
| 449 |
+
add_image_token: false
|
| 450 |
+
|
| 451 |
+
dummyarg: null
|
configs/config_empty.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- _self_
|
| 3 |
+
- /model: small
|
| 4 |
+
- /experiments: []
|
| 5 |
+
|
| 6 |
+
# from omegaconf import OmegaConf
|
| 7 |
+
# with open("config.yaml", "w") as fp:
|
| 8 |
+
# OmegaConf.save(config=config, f=fp.name)
|
configs/experiments/ar.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
parameterization: ar
|
| 4 |
+
|
| 5 |
+
trainer:
|
| 6 |
+
ar_shift: true
|
| 7 |
+
|
| 8 |
+
model:
|
| 9 |
+
full_attention: false
|
| 10 |
+
use_flex_attention: false
|
configs/experiments/elm.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
backbone: elm
|
| 4 |
+
|
| 5 |
+
data:
|
| 6 |
+
tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
|
| 7 |
+
|
| 8 |
+
model:
|
| 9 |
+
use_lora: false
|
| 10 |
+
full_attention: true
|
| 11 |
+
model_id: apple/OpenELM-270M # apple/OpenELM-1_1B
|
| 12 |
+
|
| 13 |
+
trainer:
|
| 14 |
+
use_gradient_checkpointing: false
|
| 15 |
+
sd3_compile_config: false
|
configs/experiments/eval_model.yaml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
mode: eval
|
| 4 |
+
|
| 5 |
+
loader:
|
| 6 |
+
batch_size: 16
|
| 7 |
+
eval_batch_size: 16
|
| 8 |
+
|
| 9 |
+
trainer:
|
| 10 |
+
disable_all_eval_generation: false
|
| 11 |
+
|
| 12 |
+
eval:
|
| 13 |
+
compute_generative_perplexity: true
|
| 14 |
+
generate_samples: true
|
| 15 |
+
num_sample_batches: 20
|
| 16 |
+
log_every_n_fid: 1
|
| 17 |
+
log_every_n_evals: 1
|
| 18 |
+
compute_standalone_mauve: true
|
| 19 |
+
mauve_num_samples: 5000
|
| 20 |
+
# mauve_divergence_curve_discretization_size: 200 # works well for our repo
|
| 21 |
+
# mauve_scaling_factor: 2 # works well for our repo
|
configs/experiments/eval_text.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
mode: eval
|
| 4 |
+
|
| 5 |
+
sampling:
|
| 6 |
+
steps: 100
|
| 7 |
+
max_sampling_steps: 100
|
| 8 |
+
|
| 9 |
+
loader:
|
| 10 |
+
batch_size: 2
|
| 11 |
+
eval_batch_size: 2
|
| 12 |
+
|
| 13 |
+
trainer:
|
| 14 |
+
fsdp: false
|
| 15 |
+
|
| 16 |
+
eval:
|
| 17 |
+
perplexity_batch_size: 2
|
| 18 |
+
num_masking_viz_batches: 2
|
| 19 |
+
log_every_n_evals: 1
|
| 20 |
+
num_uncond_sample_batches: 2
|
| 21 |
+
num_sample_batches: 2
|
| 22 |
+
num_random_masking: 1
|
| 23 |
+
masking_batch_size: 2
|
| 24 |
+
cfg: null
|
| 25 |
+
generate_samples: true
|
| 26 |
+
compute_generative_perplexity: false
|
configs/experiments/eval_text_only.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
mode: eval
|
| 4 |
+
debug: true
|
| 5 |
+
|
| 6 |
+
sampling:
|
| 7 |
+
steps: 100
|
| 8 |
+
max_sampling_steps: 100
|
| 9 |
+
|
| 10 |
+
loader:
|
| 11 |
+
batch_size: 2
|
| 12 |
+
eval_batch_size: 2
|
| 13 |
+
|
| 14 |
+
trainer:
|
| 15 |
+
fsdp: false
|
| 16 |
+
|
| 17 |
+
model:
|
| 18 |
+
image_model_fid_eval: false
|
| 19 |
+
|
| 20 |
+
eval:
|
| 21 |
+
log_every_n_evals: 1
|
| 22 |
+
perplexity_batch_size: 2
|
| 23 |
+
num_uncond_sample_batches: 2
|
| 24 |
+
num_sample_batches: 2
|
| 25 |
+
num_masking_viz_batches: -1
|
| 26 |
+
num_random_masking: -1
|
| 27 |
+
masking_batch_size: -1
|
| 28 |
+
cfg: null
|
| 29 |
+
generate_samples: true
|
| 30 |
+
compute_generative_perplexity: true
|
configs/experiments/eval_unified.yaml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
mode: eval
|
| 4 |
+
devices: ${device_count:}
|
| 5 |
+
|
| 6 |
+
sampling:
|
| 7 |
+
steps: 500
|
| 8 |
+
max_sampling_steps: 1000
|
| 9 |
+
|
| 10 |
+
loader:
|
| 11 |
+
batch_size: 6
|
| 12 |
+
eval_batch_size: 6
|
| 13 |
+
|
| 14 |
+
trainer:
|
| 15 |
+
fsdp: false
|
| 16 |
+
disable_all_eval_generation: false
|
| 17 |
+
|
| 18 |
+
eval:
|
| 19 |
+
perplexity_batch_size: 6
|
| 20 |
+
num_masking_viz_batches: 12
|
| 21 |
+
log_every_n_evals: 1
|
| 22 |
+
num_uncond_sample_batches: 5
|
| 23 |
+
num_sample_batches: 2
|
| 24 |
+
num_random_masking: 3
|
| 25 |
+
masking_batch_size: 6
|
| 26 |
+
cfg: 6.0
|
| 27 |
+
generate_samples: false
|
configs/experiments/fid_cc12m.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
data:
|
| 4 |
+
keep_hf_dataset_in_memory: true
|
| 5 |
+
aggressive_aug: false
|
| 6 |
+
n_duplicate_train: null
|
| 7 |
+
n_duplicate_val: null
|
| 8 |
+
|
| 9 |
+
tokenize_vqvae_in_dataloader: false
|
| 10 |
+
enable_cuda_in_tensordict_collate: false
|
| 11 |
+
force_mp_spawn: false
|
| 12 |
+
keep_tensordict_on_disk: false
|
| 13 |
+
move_tensordict_to_shm: false
|
| 14 |
+
|
| 15 |
+
fid_dataset: cc12m_tokens_val_256
|
| 16 |
+
image_data_train: null
|
| 17 |
+
image_data_val: null
|
| 18 |
+
data_dir_train: ${data.data_dir_val}
|
| 19 |
+
data_dir_val:
|
| 20 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
|
| 21 |
+
weight: 1
|
| 22 |
+
name: ${data.fid_dataset}
|
configs/experiments/fid_datacomp1b.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
data:
|
| 4 |
+
keep_hf_dataset_in_memory: true
|
| 5 |
+
aggressive_aug: false
|
| 6 |
+
n_duplicate_train: null
|
| 7 |
+
n_duplicate_val: null
|
| 8 |
+
|
| 9 |
+
tokenize_vqvae_in_dataloader: false
|
| 10 |
+
enable_cuda_in_tensordict_collate: false
|
| 11 |
+
force_mp_spawn: false
|
| 12 |
+
keep_tensordict_on_disk: false
|
| 13 |
+
move_tensordict_to_shm: false
|
| 14 |
+
|
| 15 |
+
fid_dataset: datacomp1b_8_magvit_val
|
| 16 |
+
image_data_train: null
|
| 17 |
+
image_data_val: null
|
| 18 |
+
data_dir_train: ${data.data_dir_val}
|
| 19 |
+
data_dir_val:
|
| 20 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
|
| 21 |
+
weight: -1
|
| 22 |
+
name: ${data.fid_dataset}
|
configs/experiments/fid_hf.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
data:
|
| 4 |
+
disable_text_modality: false
|
| 5 |
+
keep_hf_dataset_in_memory: true
|
| 6 |
+
aggressive_aug: false
|
| 7 |
+
n_duplicate_train: null
|
| 8 |
+
n_duplicate_val: null
|
| 9 |
+
data_dir_train: []
|
| 10 |
+
data_dir_val: []
|
| 11 |
+
fid_dataset: sayakpaul/coco-30-val-2014
|
| 12 |
+
train: combined_tokens
|
| 13 |
+
val: {.train}
|
| 14 |
+
image_data_val:
|
| 15 |
+
- val: ${data.fid_dataset}
|
| 16 |
+
weight: -1
|
| 17 |
+
name: ${.val}
|
| 18 |
+
tokenize_vqvae_in_dataloader: false
|
| 19 |
+
raw_images: true
|
| 20 |
+
image_data_train:
|
| 21 |
+
- train: ${data.fid_dataset}
|
| 22 |
+
weight: -1
|
| 23 |
+
name: ${.train}
|
| 24 |
+
tokenize_vqvae_in_dataloader: false
|
| 25 |
+
raw_images: true
|
configs/experiments/jan_cub.yaml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /model: medium
|
| 5 |
+
- override /lr_scheduler: cosine_with_hard_restarts_schedule_with_warmup
|
| 6 |
+
|
| 7 |
+
loader:
|
| 8 |
+
batch_size: 16
|
| 9 |
+
eval_batch_size: 16
|
| 10 |
+
desired_global_batch_size: 128
|
| 11 |
+
num_workers: 4
|
| 12 |
+
|
| 13 |
+
trainer:
|
| 14 |
+
ckpt_steps: 5000
|
| 15 |
+
val_check_interval: 100
|
| 16 |
+
use_legacy_update_batch_fn: true
|
| 17 |
+
mask_txt_only: true
|
| 18 |
+
mask_entire_modality: 0.15
|
| 19 |
+
ema: 0.9999
|
| 20 |
+
use_custom_ema: true
|
| 21 |
+
force_enable_checkpointing: true
|
| 22 |
+
skip_early_checkpointing: false
|
| 23 |
+
force_after_eos_padding: false
|
| 24 |
+
|
| 25 |
+
checkpointing:
|
| 26 |
+
checkpoints_total_limit: 20
|
| 27 |
+
|
| 28 |
+
lr_scheduler:
|
| 29 |
+
num_warmup_steps: 10000
|
| 30 |
+
num_training_steps: 400000
|
| 31 |
+
num_cycles: 80
|
| 32 |
+
|
| 33 |
+
data:
|
| 34 |
+
resolution: 256
|
| 35 |
+
train: cub2011_custom
|
| 36 |
+
use_weighted_tensordict_sampler: false
|
| 37 |
+
|
| 38 |
+
model:
|
| 39 |
+
vae_type: titok128
|
| 40 |
+
txt_length: 18
|
| 41 |
+
img_length: 128
|
| 42 |
+
rope_2d: false
|
| 43 |
+
force_text_vocab_size: 5450
|
| 44 |
+
text_vocab_size: 5451
|
| 45 |
+
image_vocab_size: 8192
|
| 46 |
+
attn_dropout: 0.1
|
| 47 |
+
|
| 48 |
+
optim:
|
| 49 |
+
lr: 1.0e-04
|
| 50 |
+
weight_decay: 0.2
|
| 51 |
+
beta2: 0.99
|
configs/experiments/large_maskdit_exp.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /model: large_maskdit
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
backbone: maskdit
|
configs/experiments/large_scale_high_res_interleaved_inference.yaml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
debug: true
|
| 4 |
+
seed: 163
|
| 5 |
+
|
| 6 |
+
loader:
|
| 7 |
+
eval_batch_size: 1
|
| 8 |
+
batch_size: 1
|
| 9 |
+
|
| 10 |
+
data:
|
| 11 |
+
move_tensordict_to_shm: false
|
| 12 |
+
resolution: 1024
|
| 13 |
+
disable_mask_after_eos: true
|
| 14 |
+
disable_packing: true
|
| 15 |
+
data_dir_val:
|
| 16 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
|
| 17 |
+
weight: 1.0
|
| 18 |
+
name: HPDv2_image_reward_512
|
| 19 |
+
|
| 20 |
+
model:
|
| 21 |
+
img_length: 4096
|
| 22 |
+
txt_length: 1024
|
| 23 |
+
length: 5120
|
| 24 |
+
|
| 25 |
+
trainer:
|
| 26 |
+
compile: false
|
| 27 |
+
limit_val_batches: 2
|
| 28 |
+
fsdp: false
|
| 29 |
+
force_full_attention_mask: true
|
| 30 |
+
force_null_sigma: true
|
| 31 |
+
allow_null_sigma: true
|
| 32 |
+
|
| 33 |
+
eval:
|
| 34 |
+
num_sample_batches: 1
|
| 35 |
+
num_random_masking: 0
|
| 36 |
+
num_masking_viz_batches: 0
|
| 37 |
+
limit_val_batches_manual: 1
|
| 38 |
+
num_uncond_sample_batches: 10
|
| 39 |
+
eval_large_batch: 10
|
| 40 |
+
val_with_train_data: false
|
| 41 |
+
maskgit_r_temp: 4.5
|
| 42 |
+
half_uncond: false
|
| 43 |
+
cfg: 3.0
|
| 44 |
+
return_interleaved_modalities_split: true
|
| 45 |
+
static_img_txt_demo: true
|
| 46 |
+
visualize_sample: true
|
| 47 |
+
|
| 48 |
+
sampling:
|
| 49 |
+
steps: 50
|
| 50 |
+
max_sampling_steps: 50
|
| 51 |
+
predictor: "maskgit"
|
configs/experiments/large_scale_train.yaml
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- vq16_t2i
|
| 5 |
+
- override /model: extra_large
|
| 6 |
+
|
| 7 |
+
data:
|
| 8 |
+
train: combined_tokens
|
| 9 |
+
valid: ${.train}
|
| 10 |
+
precache: false
|
| 11 |
+
streaming: false
|
| 12 |
+
resolution: 256
|
| 13 |
+
block_size: 128
|
| 14 |
+
tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
|
| 15 |
+
wrap: true
|
| 16 |
+
iterable: false
|
| 17 |
+
webdataset_iterable: false
|
| 18 |
+
webdataset_indexed: false
|
| 19 |
+
unpaired: false
|
| 20 |
+
dataset_type: null
|
| 21 |
+
tokens_flip_collate: false
|
| 22 |
+
n_val_samples: null
|
| 23 |
+
n_train_samples: null
|
| 24 |
+
n_duplicate_train: null
|
| 25 |
+
n_duplicate_val: null
|
| 26 |
+
raw_data_dir: null
|
| 27 |
+
save_train_dataloader: true
|
| 28 |
+
save_validation_dataloader: true
|
| 29 |
+
tokenizers_parallelism: false
|
| 30 |
+
token_data_dir: null
|
| 31 |
+
force_disable_shuffle: false
|
| 32 |
+
use_custom_tensordict_collate: true
|
| 33 |
+
use_weighted_tensordict_sampler: true
|
| 34 |
+
force_mp_spawn: false
|
| 35 |
+
enable_cuda_in_tensordict_collate: false
|
| 36 |
+
use_token_dataset: true
|
| 37 |
+
keep_tensordict_on_disk: true
|
| 38 |
+
move_tensordict_to_shm: false
|
| 39 |
+
add_text_to_weighted_sampler: false
|
| 40 |
+
data_dir_train:
|
| 41 |
+
# - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
|
| 42 |
+
# weight: 15.0
|
| 43 |
+
# name: hpdv2
|
| 44 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/pixelprose_tokens
|
| 45 |
+
weight: 1.0
|
| 46 |
+
name: pixelprose
|
| 47 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/journeydb_train
|
| 48 |
+
weight: 10.0
|
| 49 |
+
name: journeydb_train
|
| 50 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_0_tokens
|
| 51 |
+
weight: 1.0
|
| 52 |
+
name: datacomp0
|
| 53 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_1_tokens
|
| 54 |
+
weight: 1.0
|
| 55 |
+
name: datacomp1
|
| 56 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_2_tokens
|
| 57 |
+
weight: 1.0
|
| 58 |
+
name: datacomp2
|
| 59 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_3_tokens
|
| 60 |
+
weight: 1.0
|
| 61 |
+
name: datacomp3
|
| 62 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_4_tokens
|
| 63 |
+
weight: 1.0
|
| 64 |
+
name: datacomp4
|
| 65 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_5_tokens
|
| 66 |
+
weight: 1.0
|
| 67 |
+
name: datacomp5
|
| 68 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_6_tokens
|
| 69 |
+
weight: 1.0
|
| 70 |
+
name: datacomp6
|
| 71 |
+
data_dir_val:
|
| 72 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/pixelprose_tokens
|
| 73 |
+
weight: 1.0
|
| 74 |
+
name: dummy_1
|
| 75 |
+
|
| 76 |
+
model:
|
| 77 |
+
img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
|
| 78 |
+
txt_length: ${eval:'${data.block_size} if ${.unified_model} else 0'}
|
| 79 |
+
length: ${eval:'${.txt_length} + ${.img_length}'}
|
| 80 |
+
unified_model: true
|
| 81 |
+
image_model: true
|
| 82 |
+
text_model: true
|
| 83 |
+
image_model_fid_eval: false
|
| 84 |
+
force_argmax_valid_indices: true
|
| 85 |
+
use_pretrained_img_emb: false
|
| 86 |
+
rope_2d: true
|
| 87 |
+
modality_embed: true
|
| 88 |
+
norm_type: rms
|
| 89 |
+
qk_norm: true
|
| 90 |
+
sandwich_normalization: true
|
| 91 |
+
text_vocab_size: 32001
|
| 92 |
+
|
| 93 |
+
loader:
|
| 94 |
+
batch_size: 8
|
| 95 |
+
eval_batch_size: ${eval:'${.batch_size} // 2'}
|
| 96 |
+
desired_global_batch_size: 512
|
| 97 |
+
persistent_workers: true
|
| 98 |
+
pin_memory: false
|
| 99 |
+
num_workers: 0
|
| 100 |
+
num_eval_workers: 0
|
| 101 |
+
eval:
|
| 102 |
+
log_every_n_evals: -1
|
| 103 |
+
log_every_n_fid: -1
|
| 104 |
+
limit_val_batches_manual: 16
|
| 105 |
+
generate_samples: true
|
| 106 |
+
compute_generative_perplexity: false
|
| 107 |
+
perplexity_batch_size: ${loader.eval_batch_size}
|
| 108 |
+
cfg: 5.0
|
| 109 |
+
num_val_metrics_standalone_samples: -1
|
| 110 |
+
num_val_metrics_standalone_batches_per_device: -1
|
| 111 |
+
auto_enhance_reward_config:
|
| 112 |
+
dfn_score: 1.0
|
| 113 |
+
laion_aesthetic_score: 1.0
|
| 114 |
+
|
| 115 |
+
trainer:
|
| 116 |
+
log_flops: false
|
| 117 |
+
log_every_n_steps: 10
|
| 118 |
+
custom_ddp_bf16: true
|
| 119 |
+
log_seperate_modal_losses: true
|
| 120 |
+
limit_val_batches: 16
|
| 121 |
+
softmin_snr: 5
|
| 122 |
+
text_loss_weight: 1.0
|
| 123 |
+
img_loss_weight: 0.6
|
| 124 |
+
use_gradient_checkpointing: false
|
| 125 |
+
ckpt_steps: 20000
|
| 126 |
+
ckpt_every_n_minutes: 180
|
| 127 |
+
ckpt_recent_timeout_minutes: 10
|
| 128 |
+
use_custom_ema: false
|
| 129 |
+
ema: 0.0
|
| 130 |
+
fsdp: true
|
| 131 |
+
restart_on_failure: true
|
| 132 |
+
eval_on_start: false
|
| 133 |
+
val_check_interval: 100000000000
|
| 134 |
+
scale_lr_by_batch_size: false
|
| 135 |
+
watch_gradients: false
|
| 136 |
+
compile: true
|
| 137 |
+
mask_entire_modality: 0.15
|
| 138 |
+
compile_flag_pos_emb: true
|
| 139 |
+
multimodal_batches: true
|
| 140 |
+
optim:
|
| 141 |
+
lr: 0.0001
|
| 142 |
+
sampling:
|
| 143 |
+
steps: 128
|
| 144 |
+
num_sample_batches: 2
|
| 145 |
+
wandb:
|
| 146 |
+
mode: online
|
| 147 |
+
checkpointing:
|
| 148 |
+
checkpoints_total_limit: 10
|
| 149 |
+
use_automatic_naming: false
|
| 150 |
+
lr_scheduler:
|
| 151 |
+
num_warmup_steps: 10000
|
configs/experiments/large_scale_train_high_res.yaml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# @package _global_
|
| 3 |
+
|
| 4 |
+
data:
|
| 5 |
+
resolution: 512
|
| 6 |
+
data_dir_train:
|
| 7 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
|
| 8 |
+
weight: 1
|
| 9 |
+
name: HPDv2_image_reward_512
|
| 10 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_512
|
| 11 |
+
weight: 2
|
| 12 |
+
name: pick_score_sac_prompts_v1_v2_v3_512
|
| 13 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_7_512
|
| 14 |
+
weight: 0.5
|
| 15 |
+
name: datacomp1b_7_512
|
| 16 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/text/slimpajama6b
|
| 17 |
+
weight: 2.5
|
| 18 |
+
name: slimpajama6b
|
| 19 |
+
data_dir_val:
|
| 20 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/gecko_eval_512
|
| 21 |
+
weight: 1.0
|
| 22 |
+
name: gecko_eval_512
|
| 23 |
+
|
| 24 |
+
trainer:
|
| 25 |
+
text_loss_weight: 1.0
|
| 26 |
+
img_loss_weight: 0.5
|
| 27 |
+
force_full_attention_mask: true
|
| 28 |
+
mask_entire_modality: 0.1
|
| 29 |
+
|
| 30 |
+
loader:
|
| 31 |
+
pin_memory: false
|
| 32 |
+
num_workers: 4
|
| 33 |
+
num_eval_workers: 4
|
| 34 |
+
|
| 35 |
+
lr_scheduler:
|
| 36 |
+
num_warmup_steps: 5000
|
| 37 |
+
|
| 38 |
+
model:
|
| 39 |
+
linear_factor: 2
|
configs/experiments/large_scale_train_high_res_inference.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
data:
|
| 4 |
+
use_token_dataset: true
|
| 5 |
+
disable_mask_after_eos: true
|
| 6 |
+
move_tensordict_to_shm: false
|
| 7 |
+
|
| 8 |
+
trainer:
|
| 9 |
+
compile_flag_pos_emb: true
|
| 10 |
+
multimodal_batches: true
|
| 11 |
+
allow_null_sigma: true
|
| 12 |
+
|
| 13 |
+
eval:
|
| 14 |
+
num_sample_batches: 1
|
| 15 |
+
num_random_masking: 0
|
| 16 |
+
num_masking_viz_batches: 0
|
| 17 |
+
limit_val_batches_manual: 1
|
| 18 |
+
num_uncond_sample_batches: 10
|
| 19 |
+
eval_large_batch: 10
|
| 20 |
+
val_with_train_data: false
|
| 21 |
+
maskgit_r_temp: 4.5
|
| 22 |
+
half_uncond: false
|
| 23 |
+
cfg: 3.0
|
| 24 |
+
static_img_txt_demo: true
|
| 25 |
+
visualize_sample: true
|
| 26 |
+
|
| 27 |
+
sampling:
|
| 28 |
+
steps: 50
|
| 29 |
+
max_sampling_steps: 50
|
| 30 |
+
predictor: "maskgit"
|
configs/experiments/large_scale_train_high_res_interleaved.yaml
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# @package _global_
|
| 3 |
+
|
| 4 |
+
data:
|
| 5 |
+
move_tensordict_to_shm: false
|
| 6 |
+
enable_cuda_in_tensordict_collate: false
|
| 7 |
+
force_mp_spawn: false
|
| 8 |
+
resolution: 512
|
| 9 |
+
add_text_to_weighted_sampler: false
|
| 10 |
+
|
| 11 |
+
add_image_gen_tokens: true
|
| 12 |
+
use_packing_collate: true
|
| 13 |
+
dynamic_packing_lengths: true
|
| 14 |
+
remove_txt_img_padding: true
|
| 15 |
+
require_sample_ids: true
|
| 16 |
+
block_size: ${model.length}
|
| 17 |
+
disable_mask_after_eos: true
|
| 18 |
+
add_image_token: true
|
| 19 |
+
use_slow_tokenizer: true
|
| 20 |
+
force_seed: true
|
| 21 |
+
|
| 22 |
+
data_dir_train:
|
| 23 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
|
| 24 |
+
weight: 0.5
|
| 25 |
+
name: HPDv2_image_reward_v1_v2_v3 # 3593248
|
| 26 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_512
|
| 27 |
+
weight: 1.0
|
| 28 |
+
name: pick_score_sac_prompts_v1_v2_v3_512 # 9330810
|
| 29 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/pixelprose_tokens
|
| 30 |
+
weight: 1.0
|
| 31 |
+
name: pixelprose_tokens # 6627589
|
| 32 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/cambrian_10m_v5
|
| 33 |
+
weight: 1.0
|
| 34 |
+
name: cambrian_10m_v5 # 8215264
|
| 35 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_7_512
|
| 36 |
+
weight: 1.0
|
| 37 |
+
name: datacomp1b_7_512 # 23955209
|
| 38 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_2_tokens
|
| 39 |
+
weight: 0.5
|
| 40 |
+
name: datacomp_1b_datacomp1b_2_tokens # 10161505
|
| 41 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_4_tokens
|
| 42 |
+
weight: 0.5
|
| 43 |
+
name: datacomp_1b_datacomp1b_4_tokens # 27895717
|
| 44 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/mmc4_fewer_faces_v0
|
| 45 |
+
weight: 2.0
|
| 46 |
+
name: mmc4_fewer_faces_v0 # 22605524
|
| 47 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_5_tokens
|
| 48 |
+
weight: 0.5
|
| 49 |
+
name: datacomp_1b_datacomp1b_5_tokens
|
| 50 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_0_tokens
|
| 51 |
+
weight: 0.5
|
| 52 |
+
name: datacomp_1b_datacomp1b_0_tokens
|
| 53 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_1_tokens
|
| 54 |
+
weight: 0.5
|
| 55 |
+
name: datacomp_1b_datacomp1b_1_tokens
|
| 56 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/cosmopedia_2_v0
|
| 57 |
+
weight: 1.0
|
| 58 |
+
name: cosmopedia_v2
|
| 59 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/fineweb_edu_dedup_v0
|
| 60 |
+
weight: 1.0
|
| 61 |
+
name: fineweb_edu_dedup
|
| 62 |
+
data_dir_val:
|
| 63 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/gecko_eval_512
|
| 64 |
+
weight: 1.0
|
| 65 |
+
name: gecko_eval_512
|
| 66 |
+
|
| 67 |
+
trainer:
|
| 68 |
+
text_loss_weight: 1.0
|
| 69 |
+
img_loss_weight: 0.2
|
| 70 |
+
mask_entire_modality: 0.2
|
| 71 |
+
|
| 72 |
+
force_full_attention_mask: false
|
| 73 |
+
force_full_attention_mask_loss_only: false
|
| 74 |
+
disable_all_eval_generation: true
|
| 75 |
+
interleaved: true
|
| 76 |
+
interleaved_training_flex_attention: true
|
| 77 |
+
force_convert_to_dict: true
|
| 78 |
+
val_check_interval: -1
|
| 79 |
+
use_gradient_checkpointing: true
|
| 80 |
+
disable_all_checkpointing: false
|
| 81 |
+
set_max_txt_loss_ratio: true
|
| 82 |
+
gradient_clip_val: 1.0
|
| 83 |
+
skip_early_checkpointing: false
|
| 84 |
+
bypass_load_from_state_dicts_if_resuming: true
|
| 85 |
+
|
| 86 |
+
loader:
|
| 87 |
+
num_workers: 4
|
| 88 |
+
num_eval_workers: 4
|
| 89 |
+
|
| 90 |
+
lr_scheduler:
|
| 91 |
+
num_warmup_steps: 5000
|
| 92 |
+
|
| 93 |
+
model:
|
| 94 |
+
linear_factor: 2
|
| 95 |
+
use_flex_attention: true
|
| 96 |
+
use_spda_attn: true
|
| 97 |
+
|
| 98 |
+
length: 1536
|
| 99 |
+
txt_length: ${.length}
|
| 100 |
+
img_length: ${.length}
|
| 101 |
+
|
| 102 |
+
eval:
|
| 103 |
+
generate_samples: false
|
| 104 |
+
disable_visualization: true
|
| 105 |
+
|
configs/experiments/maskgit.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
model:
|
| 4 |
+
downscale_ratio: 16
|
| 5 |
+
image_vocab_size: 1024
|
| 6 |
+
vae_type: maskgit
|
configs/experiments/master_eval.yaml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
mode: eval
|
| 4 |
+
|
| 5 |
+
eval:
|
| 6 |
+
fid_samples: 4096
|
| 7 |
+
max_num_fid_batches_per_device: ${eval:'max(${eval.fid_samples} // (${trainer.devices} * ${loader.eval_batch_size}), 1)'}
|
| 8 |
+
compute_generative_perplexity: true
|
| 9 |
+
generate_samples: true
|
| 10 |
+
log_every_n_fid: 1
|
| 11 |
+
log_every_n_evals: 1
|
| 12 |
+
class_conditional_fid: false
|
| 13 |
+
txt_conditional_fid: true
|
| 14 |
+
calculate_clip_score: true
|
| 15 |
+
cfg: 5
|
| 16 |
+
num_sample_batches: 2
|
| 17 |
+
compute_standalone_mauve: false
|
| 18 |
+
mauve_num_samples: -1
|
| 19 |
+
set_random_gen_seed: true
|
| 20 |
+
# gen_ppl_eval_model_name_or_path: 'meta-llama/Meta-Llama-3-8B'
|
| 21 |
+
compute_img_to_txt_mauve_clip: true
|
| 22 |
+
compute_img_to_txt_mauve_during_unconditional_fid: true
|
| 23 |
+
force_eval_uncond: true
|
| 24 |
+
ablation_config: true
|
| 25 |
+
compute_val_metrics_standalone: true
|
| 26 |
+
num_val_metrics_standalone_samples: 2000
|
| 27 |
+
|
| 28 |
+
trainer:
|
| 29 |
+
disable_all_eval_generation: false
|
| 30 |
+
force_after_eos_padding: true
|
| 31 |
+
|
| 32 |
+
model:
|
| 33 |
+
image_model_fid_eval: true
|
| 34 |
+
use_kv_cache: ${is_ar:${parameterization}}
|
| 35 |
+
|
| 36 |
+
loader:
|
| 37 |
+
batch_size: 64
|
| 38 |
+
eval_batch_size: 64
|
| 39 |
+
num_workers: 0
|
| 40 |
+
num_eval_workers: 1
|
| 41 |
+
|
| 42 |
+
sampling:
|
| 43 |
+
steps: ${model.length}
|
| 44 |
+
max_sampling_steps: ${sampling.steps}
|
| 45 |
+
sampling_step_frac: null
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
data:
|
| 49 |
+
fid_dataset: null
|
configs/experiments/mscoco_fid.yaml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
data:
|
| 4 |
+
disable_text_modality: false
|
| 5 |
+
keep_hf_dataset_in_memory: true
|
| 6 |
+
aggressive_aug: false
|
| 7 |
+
n_duplicate_train: null
|
| 8 |
+
n_duplicate_val: null
|
| 9 |
+
data_dir_train: []
|
| 10 |
+
data_dir_val: []
|
| 11 |
+
image_data_train: ${data.image_data_val}
|
| 12 |
+
image_data_val:
|
| 13 |
+
- val: sayakpaul/coco-30-val-2014
|
| 14 |
+
weight: -1
|
| 15 |
+
name: mscoco_val
|
| 16 |
+
tokenize_vqvae_in_dataloader: false
|
| 17 |
+
raw_images: true
|
| 18 |
+
|
| 19 |
+
eval:
|
| 20 |
+
compute_generative_perplexity: true
|
| 21 |
+
generate_samples: true
|
configs/experiments/paired_standalone_fid_eval.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
mode: eval
|
| 4 |
+
debug: true
|
| 5 |
+
|
| 6 |
+
eval:
|
| 7 |
+
fid_samples: 4096
|
| 8 |
+
max_num_fid_batches_per_device: ${eval:'max(${eval.fid_samples} // (${trainer.devices} * ${loader.eval_batch_size}), 1)'}
|
| 9 |
+
compute_generative_perplexity: false
|
| 10 |
+
generate_samples: false
|
| 11 |
+
log_every_n_fid: 1
|
| 12 |
+
log_every_n_evals: 1
|
| 13 |
+
class_conditional_fid: false
|
| 14 |
+
txt_conditional_fid: true
|
| 15 |
+
calculate_clip_score: true
|
| 16 |
+
cfg: 5
|
| 17 |
+
|
| 18 |
+
model:
|
| 19 |
+
image_model_fid_eval: true
|
| 20 |
+
|
| 21 |
+
loader:
|
| 22 |
+
eval_batch_size: 32
|
| 23 |
+
|
| 24 |
+
sampling:
|
| 25 |
+
steps: ${model.length}
|
| 26 |
+
max_sampling_steps: ${model.length}
|
| 27 |
+
|
| 28 |
+
data:
|
| 29 |
+
keep_hf_dataset_in_memory: false
|
configs/experiments/small_scale_train.yaml
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- vq16_magvit
|
| 5 |
+
- override /model: small
|
| 6 |
+
- override /lr_scheduler: constant_warmup_cosine_decay
|
| 7 |
+
|
| 8 |
+
model:
|
| 9 |
+
img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
|
| 10 |
+
txt_length: ${eval:'${data.block_size} if ${.unified_model} else 0'}
|
| 11 |
+
length: ${eval:'${.txt_length} + ${.img_length}'}
|
| 12 |
+
image_model: true
|
| 13 |
+
text_model: true
|
| 14 |
+
unified_model: true
|
| 15 |
+
image_model_fid_eval: false
|
| 16 |
+
force_argmax_valid_indices: true
|
| 17 |
+
use_pretrained_img_emb: false
|
| 18 |
+
codebook_embed_dim: 256
|
| 19 |
+
qk_norm: true
|
| 20 |
+
norm_type: rms
|
| 21 |
+
sandwich_normalization: true
|
| 22 |
+
zero_linear_init: false
|
| 23 |
+
modality_embed: true
|
| 24 |
+
rope_2d: false
|
| 25 |
+
use_spda_attn: true
|
| 26 |
+
force_optimized_native_attn: true
|
| 27 |
+
freeze_txt_emb: false
|
| 28 |
+
add_labels: null
|
| 29 |
+
txt_dropout: null
|
| 30 |
+
text_vocab_size: 32001
|
| 31 |
+
|
| 32 |
+
data:
|
| 33 |
+
train: combined_tokens
|
| 34 |
+
valid: ${.train}
|
| 35 |
+
n_duplicate_train: null
|
| 36 |
+
wrap: true
|
| 37 |
+
streaming: false
|
| 38 |
+
precache: false
|
| 39 |
+
tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
|
| 40 |
+
resolution: 256
|
| 41 |
+
block_size: 128
|
| 42 |
+
n_val_samples: null
|
| 43 |
+
unpaired: false
|
| 44 |
+
n_duplicate_val: null
|
| 45 |
+
save_train_dataloader: true
|
| 46 |
+
save_validation_dataloader: true
|
| 47 |
+
iterable: false
|
| 48 |
+
webdataset_iterable: false
|
| 49 |
+
webdataset_indexed: false
|
| 50 |
+
dataset_type: null
|
| 51 |
+
tokens_flip_collate: false
|
| 52 |
+
n_train_samples: null
|
| 53 |
+
raw_data_dir: null
|
| 54 |
+
tokenizers_parallelism: false
|
| 55 |
+
token_data_dir: null
|
| 56 |
+
force_disable_shuffle: false
|
| 57 |
+
keep_tensordict_on_disk: true
|
| 58 |
+
use_custom_tensordict_collate: true
|
| 59 |
+
force_mp_spawn: false
|
| 60 |
+
enable_cuda_in_tensordict_collate: false
|
| 61 |
+
use_weighted_tensordict_sampler: true
|
| 62 |
+
fraction_txt_data: 0.0
|
| 63 |
+
tokenize_vqvae_in_dataloader: false
|
| 64 |
+
use_token_dataset: true
|
| 65 |
+
image_dataset: tglcourse/lsun_church_train
|
| 66 |
+
image_data_train: null
|
| 67 |
+
image_data_val: null
|
| 68 |
+
keep_hf_dataset_in_memory: true
|
| 69 |
+
allow_label: false
|
| 70 |
+
disable_text_modality: true
|
| 71 |
+
force_raw_train_images: false
|
| 72 |
+
aggressive_aug: true
|
| 73 |
+
allow_aug_vqvae_dataloader: true
|
| 74 |
+
move_tensordict_to_shm: false
|
| 75 |
+
data_dir_train:
|
| 76 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit
|
| 77 |
+
weight: -1
|
| 78 |
+
name: datacomp1b_8_magvit_train
|
| 79 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_train_256
|
| 80 |
+
weight: -1
|
| 81 |
+
name: cc12m_tokens_train_256
|
| 82 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/HPDv2_image_reward_v1_v2_v3_magvit
|
| 83 |
+
weight: -1
|
| 84 |
+
name: HPDv2_image_reward_v1_v2_v3_magvit
|
| 85 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_magvit
|
| 86 |
+
weight: -1
|
| 87 |
+
name: pick_score_sac_prompts_v1_v2_v3_magvit
|
| 88 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/datacomp1b_0_1_6_magvit
|
| 89 |
+
weight: -1
|
| 90 |
+
name: datacomp1b_0_1_6_magvit
|
| 91 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_0
|
| 92 |
+
weight: -1
|
| 93 |
+
name: laion400m_magvit_part_0
|
| 94 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_1
|
| 95 |
+
weight: -1
|
| 96 |
+
name: laion400m_magvit_part_1
|
| 97 |
+
data_dir_val:
|
| 98 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
|
| 99 |
+
weight: 1
|
| 100 |
+
name: datacomp1b_8_magvit_val
|
| 101 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
|
| 102 |
+
weight: 1
|
| 103 |
+
name: cc12m_tokens_val_256
|
| 104 |
+
|
| 105 |
+
eval:
|
| 106 |
+
generate_samples: true
|
| 107 |
+
compute_generative_perplexity: true
|
| 108 |
+
log_every_n_evals: 10
|
| 109 |
+
log_every_n_fid: 20
|
| 110 |
+
limit_val_batches_manual: 16
|
| 111 |
+
perplexity_batch_size: ${loader.eval_batch_size}
|
| 112 |
+
num_masking_viz_batches: -1
|
| 113 |
+
cfg: null
|
| 114 |
+
class_conditional_fid: false
|
| 115 |
+
force_cfg_value: true
|
| 116 |
+
split_cfg_batches: true
|
| 117 |
+
max_num_fid_batches_per_device: ${eval:'8192 // (${trainer.devices} * ${loader.eval_batch_size})'}
|
| 118 |
+
fid_mode: clean
|
| 119 |
+
clean_fid_precomputed_name: lsun_church
|
| 120 |
+
clean_fid_precomputed_split: trainfull
|
| 121 |
+
clean_fid_precomputed_res: 256
|
| 122 |
+
|
| 123 |
+
trainer:
|
| 124 |
+
log_every_n_steps: 10
|
| 125 |
+
val_check_interval: 1000
|
| 126 |
+
custom_ddp_bf16: true
|
| 127 |
+
scale_lr_by_batch_size: false
|
| 128 |
+
limit_val_batches: 16
|
| 129 |
+
use_gradient_checkpointing: false
|
| 130 |
+
log_seperate_modal_losses: true
|
| 131 |
+
softmin_snr: 5
|
| 132 |
+
text_loss_weight: 1.0
|
| 133 |
+
img_loss_weight: null
|
| 134 |
+
low_precision_loss: false
|
| 135 |
+
compile: true
|
| 136 |
+
multimodal_batches: true
|
| 137 |
+
compile_fullgraph: false
|
| 138 |
+
log_grad_norm_every_n_steps: 10
|
| 139 |
+
mask_entire_modality: 0.1
|
| 140 |
+
force_shift_image_batches: false
|
| 141 |
+
ckpt_steps: 10000
|
| 142 |
+
ckpt_every_n_minutes: -1
|
| 143 |
+
ignore_text_in_unified: false
|
| 144 |
+
disable_all_eval_generation: true
|
| 145 |
+
eval_on_start: false
|
| 146 |
+
ckpt_model_only: false
|
| 147 |
+
ema: 0.0
|
| 148 |
+
use_custom_ema: false
|
| 149 |
+
log_flops: false
|
| 150 |
+
disable_distributed_torchmetrics: true
|
| 151 |
+
restart_on_failure: true
|
| 152 |
+
force_null_sigma: true
|
| 153 |
+
allow_null_sigma: true
|
| 154 |
+
compile_flag_pos_emb: true
|
| 155 |
+
add_label: false
|
| 156 |
+
first_token_dropout: null
|
| 157 |
+
force_shift_raw_image_batches: true
|
| 158 |
+
txt_dropout: 0.1
|
| 159 |
+
force_full_attention_mask_loss_only: true
|
| 160 |
+
|
| 161 |
+
optim:
|
| 162 |
+
lr: 0.0003
|
| 163 |
+
weight_decay: 0.05
|
| 164 |
+
|
| 165 |
+
loader:
|
| 166 |
+
batch_size: 64
|
| 167 |
+
eval_batch_size: ${loader.batch_size}
|
| 168 |
+
num_workers: 4
|
| 169 |
+
desired_global_batch_size: 512
|
| 170 |
+
persistent_workers: true
|
| 171 |
+
pin_memory: true
|
| 172 |
+
num_eval_workers: 1
|
| 173 |
+
|
| 174 |
+
sampling:
|
| 175 |
+
steps: ${model.length}
|
| 176 |
+
num_sample_batches: 2
|
| 177 |
+
max_sampling_steps: ${model.length}
|
| 178 |
+
|
| 179 |
+
wandb:
|
| 180 |
+
mode: online
|
| 181 |
+
|
| 182 |
+
lr_scheduler:
|
| 183 |
+
num_warmup_steps: 5000
|
| 184 |
+
num_training_steps: ${trainer.max_steps}
|
| 185 |
+
|
| 186 |
+
checkpointing:
|
| 187 |
+
checkpoints_total_limit: 10
|
configs/experiments/small_scale_train_caching.yaml
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- /model: small
|
| 5 |
+
|
| 6 |
+
model:
|
| 7 |
+
downscale_ratio: 16
|
| 8 |
+
image_vocab_size: 8192
|
| 9 |
+
vae_type: magvit
|
| 10 |
+
use_custom_vae_ckpt: null
|
| 11 |
+
custom_vae_name: null
|
| 12 |
+
img_length: 256
|
| 13 |
+
txt_length: 128
|
| 14 |
+
image_model: true
|
| 15 |
+
text_model: true
|
| 16 |
+
unified_model: true
|
| 17 |
+
image_model_fid_eval: false
|
| 18 |
+
force_argmax_valid_indices: true
|
| 19 |
+
use_pretrained_img_emb: false
|
| 20 |
+
codebook_embed_dim: 256
|
| 21 |
+
qk_norm: true
|
| 22 |
+
norm_type: rms
|
| 23 |
+
sandwich_normalization: true
|
| 24 |
+
zero_linear_init: false
|
| 25 |
+
modality_embed: true
|
| 26 |
+
rope_2d: false
|
| 27 |
+
use_spda_attn: true
|
| 28 |
+
force_optimized_native_attn: true
|
| 29 |
+
freeze_txt_emb: false
|
| 30 |
+
add_labels: null
|
| 31 |
+
txt_dropout: null
|
| 32 |
+
text_vocab_size: 32001
|
| 33 |
+
use_flex_attention: true
|
| 34 |
+
flex_attention_txt_masking_prob: 0.1
|
| 35 |
+
flex_attention_img_masking_prob: 0.1
|
| 36 |
+
linear_factor: 1
|
| 37 |
+
data:
|
| 38 |
+
train: combined_tokens
|
| 39 |
+
valid: ${.train}
|
| 40 |
+
n_duplicate_train: null
|
| 41 |
+
wrap: true
|
| 42 |
+
streaming: false
|
| 43 |
+
precache: false
|
| 44 |
+
tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
|
| 45 |
+
resolution: 256
|
| 46 |
+
block_size: 128
|
| 47 |
+
n_val_samples: null
|
| 48 |
+
unpaired: false
|
| 49 |
+
n_duplicate_val: null
|
| 50 |
+
save_train_dataloader: true
|
| 51 |
+
save_validation_dataloader: true
|
| 52 |
+
iterable: false
|
| 53 |
+
webdataset_iterable: false
|
| 54 |
+
webdataset_indexed: false
|
| 55 |
+
dataset_type: null
|
| 56 |
+
tokens_flip_collate: false
|
| 57 |
+
n_train_samples: null
|
| 58 |
+
raw_data_dir: null
|
| 59 |
+
tokenizers_parallelism: false
|
| 60 |
+
token_data_dir: null
|
| 61 |
+
force_disable_shuffle: false
|
| 62 |
+
keep_tensordict_on_disk: true
|
| 63 |
+
use_custom_tensordict_collate: true
|
| 64 |
+
force_mp_spawn: false
|
| 65 |
+
enable_cuda_in_tensordict_collate: false
|
| 66 |
+
use_weighted_tensordict_sampler: true
|
| 67 |
+
fraction_txt_data: 0.0
|
| 68 |
+
data_dir_train:
|
| 69 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit
|
| 70 |
+
weight: -1
|
| 71 |
+
name: datacomp1b_8_magvit_train
|
| 72 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_train_256
|
| 73 |
+
weight: -1
|
| 74 |
+
name: cc12m_tokens_train_256
|
| 75 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/HPDv2_image_reward_v1_v2_v3_magvit
|
| 76 |
+
weight: -1
|
| 77 |
+
name: HPDv2_image_reward_v1_v2_v3_magvit
|
| 78 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_magvit
|
| 79 |
+
weight: -1
|
| 80 |
+
name: pick_score_sac_prompts_v1_v2_v3_magvit
|
| 81 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/datacomp1b_0_1_6_magvit
|
| 82 |
+
weight: -1
|
| 83 |
+
name: datacomp1b_0_1_6_magvit
|
| 84 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_0
|
| 85 |
+
weight: -1
|
| 86 |
+
name: laion400m_magvit_part_0
|
| 87 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_1
|
| 88 |
+
weight: -1
|
| 89 |
+
name: laion400m_magvit_part_1
|
| 90 |
+
data_dir_val:
|
| 91 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
|
| 92 |
+
weight: 1
|
| 93 |
+
name: datacomp1b_8_magvit_val
|
| 94 |
+
- dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
|
| 95 |
+
weight: 1
|
| 96 |
+
name: cc12m_tokens_val_256
|
| 97 |
+
tokenize_vqvae_in_dataloader: false
|
| 98 |
+
val:
|
| 99 |
+
.train: null
|
| 100 |
+
use_token_dataset: true
|
| 101 |
+
image_dataset: tglcourse/lsun_church_train
|
| 102 |
+
image_data_train: null
|
| 103 |
+
image_data_val: null
|
| 104 |
+
keep_hf_dataset_in_memory: true
|
| 105 |
+
allow_label: false
|
| 106 |
+
disable_text_modality: true
|
| 107 |
+
force_raw_train_images: false
|
| 108 |
+
aggressive_aug: true
|
| 109 |
+
allow_aug_vqvae_dataloader: true
|
| 110 |
+
move_tensordict_to_shm: false
|
| 111 |
+
force_full_attention_mask: false
|
| 112 |
+
eval:
|
| 113 |
+
generate_samples: false
|
| 114 |
+
compute_generative_perplexity: false
|
| 115 |
+
log_every_n_evals: 10
|
| 116 |
+
log_every_n_fid: 20
|
| 117 |
+
limit_val_batches_manual: 16
|
| 118 |
+
perplexity_batch_size: ${loader.eval_batch_size}
|
| 119 |
+
num_masking_viz_batches: -1
|
| 120 |
+
max_num_fid_batches_per_device: ${eval:'8192 // (${trainer.devices} * ${loader.eval_batch_size})'}
|
| 121 |
+
cfg: null
|
| 122 |
+
class_conditional_fid: false
|
| 123 |
+
force_cfg_value: true
|
| 124 |
+
split_cfg_batches: true
|
| 125 |
+
fid_mode: clean
|
| 126 |
+
clean_fid_precomputed_name: lsun_church
|
| 127 |
+
clean_fid_precomputed_split: trainfull
|
| 128 |
+
clean_fid_precomputed_res: 256
|
| 129 |
+
trainer:
|
| 130 |
+
log_every_n_steps: 10
|
| 131 |
+
val_check_interval: 1000
|
| 132 |
+
custom_ddp_bf16: true
|
| 133 |
+
scale_lr_by_batch_size: false
|
| 134 |
+
limit_val_batches: 16
|
| 135 |
+
use_gradient_checkpointing: false
|
| 136 |
+
log_seperate_modal_losses: true
|
| 137 |
+
softmin_snr: 5
|
| 138 |
+
text_loss_weight: 1.0
|
| 139 |
+
img_loss_weight: null
|
| 140 |
+
low_precision_loss: false
|
| 141 |
+
compile: false
|
| 142 |
+
multimodal_batches: true
|
| 143 |
+
compile_fullgraph: false
|
| 144 |
+
log_grad_norm_every_n_steps: 10
|
| 145 |
+
mask_entire_modality: 0.1
|
| 146 |
+
force_shift_image_batches: false
|
| 147 |
+
ckpt_steps: 10000
|
| 148 |
+
ckpt_every_n_minutes: -1
|
| 149 |
+
ignore_text_in_unified: false
|
| 150 |
+
disable_all_eval_generation: false
|
| 151 |
+
eval_on_start: false
|
| 152 |
+
ckpt_model_only: false
|
| 153 |
+
ema: 0.0
|
| 154 |
+
use_custom_ema: false
|
| 155 |
+
log_flops: false
|
| 156 |
+
disable_distributed_torchmetrics: true
|
| 157 |
+
restart_on_failure: true
|
| 158 |
+
force_null_sigma: true
|
| 159 |
+
allow_null_sigma: true
|
| 160 |
+
compile_flag_pos_emb: true
|
| 161 |
+
add_label: false
|
| 162 |
+
first_token_dropout: null
|
| 163 |
+
force_shift_raw_image_batches: true
|
| 164 |
+
txt_dropout: 0.1
|
| 165 |
+
disable_ddp_optimizer: true
|
| 166 |
+
optim:
|
| 167 |
+
lr: 0.0003
|
| 168 |
+
weight_decay: 0.05
|
| 169 |
+
loader:
|
| 170 |
+
batch_size: 64
|
| 171 |
+
eval_batch_size: ${loader.batch_size}
|
| 172 |
+
num_workers: 1
|
| 173 |
+
desired_global_batch_size: 512
|
| 174 |
+
persistent_workers: true
|
| 175 |
+
pin_memory: true
|
| 176 |
+
num_eval_workers: 1
|
| 177 |
+
sampling:
|
| 178 |
+
steps: ${model.length}
|
| 179 |
+
num_sample_batches: 2
|
| 180 |
+
max_sampling_steps: ${model.length}
|
| 181 |
+
wandb:
|
| 182 |
+
mode: online
|
| 183 |
+
lr_scheduler:
|
| 184 |
+
num_warmup_steps: 5000
|
| 185 |
+
checkpointing:
|
| 186 |
+
checkpoints_total_limit: 4
|
configs/experiments/small_text_only.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- lsun_text8_exp_2
|
| 5 |
+
- owt_only
|
| 6 |
+
- override /model: small
|
| 7 |
+
|
| 8 |
+
backbone: dit
|
| 9 |
+
|
| 10 |
+
loader:
|
| 11 |
+
batch_size: 64
|
| 12 |
+
|
| 13 |
+
trainer:
|
| 14 |
+
val_check_interval: 10000
|
| 15 |
+
ckpt_steps: 10000
|
| 16 |
+
softmin_snr: null
|
| 17 |
+
|
| 18 |
+
optim:
|
| 19 |
+
fused: true
|
| 20 |
+
weight_decay: 0.03
|
| 21 |
+
|
| 22 |
+
sampling:
|
| 23 |
+
num_sample_batches: 4
|
| 24 |
+
max_sampling_steps: 256
|
| 25 |
+
|
| 26 |
+
model:
|
| 27 |
+
txt_length: 1024
|
| 28 |
+
|
configs/experiments/standalone_fid_eval.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
mode: eval
|
| 4 |
+
debug: true
|
| 5 |
+
|
| 6 |
+
eval:
|
| 7 |
+
max_num_fid_batches_per_device: ${eval:'4096 // (${trainer.devices} * ${loader.eval_batch_size})'}
|
| 8 |
+
compute_generative_perplexity: false
|
| 9 |
+
generate_samples: false
|
| 10 |
+
log_every_n_fid: 1
|
| 11 |
+
log_every_n_evals: 1
|
| 12 |
+
|
| 13 |
+
loader:
|
| 14 |
+
eval_batch_size: 32
|
| 15 |
+
|
| 16 |
+
sampling:
|
| 17 |
+
steps: 500
|
| 18 |
+
max_sampling_steps: 500
|
configs/experiments/titok.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
data:
|
| 4 |
+
resolution: 256
|
| 5 |
+
downscale_ratio: 16
|
| 6 |
+
|
| 7 |
+
model:
|
| 8 |
+
vae_type: titok
|
configs/experiments/titok_sl256.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
data:
|
| 4 |
+
resolution: 256
|
| 5 |
+
|
| 6 |
+
model:
|
| 7 |
+
vae_type: titok
|
configs/experiments/txt_only.yaml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
data:
|
| 4 |
+
streaming: False
|
| 5 |
+
unpaired: false
|
| 6 |
+
|
| 7 |
+
trainer:
|
| 8 |
+
img_loss_weight: null
|
| 9 |
+
text_loss_weight: null
|
| 10 |
+
|
| 11 |
+
model:
|
| 12 |
+
use_pretrained_img_emb: false
|
| 13 |
+
image_model_fid_eval: false
|
| 14 |
+
unified_model: false
|
| 15 |
+
image_model: false
|
| 16 |
+
txt_length: 256
|
| 17 |
+
img_length: 0
|
| 18 |
+
|
| 19 |
+
eval:
|
| 20 |
+
log_every_n_evals: -1
|
| 21 |
+
log_every_n_fid: -1
|
configs/experiments/unified.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
data:
|
| 4 |
+
zero_shot_eval_dataset: "nlphuji/flickr30k"
|
| 5 |
+
precache: False
|
| 6 |
+
tokenizers_parallelism: False # parallelism causes some weird error
|
| 7 |
+
n_val_samples: 2048
|
| 8 |
+
block_size: 128
|
| 9 |
+
|
| 10 |
+
model:
|
| 11 |
+
unified_model: True
|
| 12 |
+
text_model: true
|
| 13 |
+
|
| 14 |
+
checkpointing:
|
| 15 |
+
resume_from_ckpt: True
|
| 16 |
+
load_from_text_model: "ckpts/unidisc-owt/model.safetensors"
|
| 17 |
+
|
| 18 |
+
loader:
|
| 19 |
+
batch_size: 12
|
| 20 |
+
|
| 21 |
+
trainer:
|
| 22 |
+
val_check_interval: 2000
|
| 23 |
+
log_seperate_modal_losses: true
|
configs/experiments/vq16.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
model:
|
| 4 |
+
downscale_ratio: 16
|
| 5 |
+
image_vocab_size: 16384
|
| 6 |
+
vae_type: VQ-16
|
| 7 |
+
use_custom_vae_ckpt: null
|
| 8 |
+
custom_vae_name: null
|
| 9 |
+
img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
|
configs/experiments/vq16_1024.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
model:
|
| 4 |
+
downscale_ratio: 16
|
| 5 |
+
image_vocab_size: 1024
|
| 6 |
+
codebook_embed_dim: 256
|
| 7 |
+
vae_type: VQ-16
|
| 8 |
+
use_custom_vae_ckpt: ${oc.env:DIFFUSION_DATA_DIR}/ckpts/2024-07-03-01-10-53_022-VQ-16_0042000.pt
|
configs/experiments/vq16_magvit.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
model:
|
| 4 |
+
downscale_ratio: 16
|
| 5 |
+
image_vocab_size: 8192
|
| 6 |
+
vae_type: magvit
|
| 7 |
+
use_custom_vae_ckpt: null
|
| 8 |
+
custom_vae_name: null
|
| 9 |
+
img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
|
configs/experiments/vq16_t2i.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
model:
|
| 4 |
+
downscale_ratio: 16
|
| 5 |
+
image_vocab_size: 16384
|
| 6 |
+
vae_type: VQ-16
|
| 7 |
+
use_custom_vae_ckpt: ${get_repo_dir:}/ckpts/vq_ds16_t2i.pt
|
| 8 |
+
custom_vae_name: _t2i
|
| 9 |
+
codebook_embed_dim: 8
|
| 10 |
+
img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
|
configs/experiments/webdataset.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
data:
|
| 4 |
+
train: datacomp1b_indexed
|
| 5 |
+
valid: ${.train}
|
| 6 |
+
|
| 7 |
+
iterable: false
|
| 8 |
+
webdataset_iterable: false
|
| 9 |
+
webdataset_indexed: true
|
| 10 |
+
unpaired: false
|
| 11 |
+
dataset_type: null
|
| 12 |
+
tokens_flip_collate: false
|
configs/experiments/zero_shot_eval.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
mode: zero-shot-eval
|
| 4 |
+
|
| 5 |
+
data:
|
| 6 |
+
# train: "nlphuji/flickr30k"
|
| 7 |
+
train: "facebook/winoground"
|
| 8 |
+
precache: False
|
| 9 |
+
tokenizers_parallelism: False # parallelism causes some weird error
|
| 10 |
+
n_val_samples: 2048
|
| 11 |
+
block_size: 128
|
| 12 |
+
disable_text_modality: false
|
| 13 |
+
|
| 14 |
+
eval:
|
| 15 |
+
cfg: 5
|
| 16 |
+
compute_val_metrics_standalone: false
|
| 17 |
+
compute_img_to_txt_mauve_clip: false
|
| 18 |
+
|
| 19 |
+
loader:
|
| 20 |
+
batch_size: 16
|
| 21 |
+
eval_batch_size: 16
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
model:
|
| 25 |
+
unified_model: True
|
| 26 |
+
text_model: true
|
| 27 |
+
image_model: true
|
| 28 |
+
vae_type: magvit
|
| 29 |
+
force_optimized_native_attn: false
|
configs/lr_scheduler/constant_warmup.yaml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: transformers.get_constant_schedule_with_warmup
|
| 2 |
+
num_warmup_steps: 2500
|
configs/lr_scheduler/constant_warmup_cosine_decay.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: transformers.get_cosine_schedule_with_warmup
|
| 2 |
+
num_warmup_steps: 2500
|
| 3 |
+
num_training_steps: 1000000
|
configs/lr_scheduler/cosine_decay_warmup.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: utils.CosineDecayWarmupLRScheduler
|
| 2 |
+
t_in_epochs: False
|
| 3 |
+
t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
|
| 4 |
+
warmup_prefix: True
|
| 5 |
+
warmup_lr_init: 1e-6
|
| 6 |
+
warmup_t: ${eval:0.1*${trainer.max_steps}}
|
| 7 |
+
lr_min: 1e-6
|
configs/lr_scheduler/cosine_with_hard_restarts_schedule_with_warmup.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: transformers.get_cosine_with_hard_restarts_schedule_with_warmup
|
| 2 |
+
num_warmup_steps: 2500
|
| 3 |
+
num_training_steps: 1000000
|
| 4 |
+
num_cycles: 1
|
configs/model/extra_large.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: extra_large
|
| 2 |
+
type: ddit
|
| 3 |
+
hidden_size: 2048
|
| 4 |
+
cond_dim: 128
|
| 5 |
+
length: 1024
|
| 6 |
+
n_blocks: 24
|
| 7 |
+
n_heads: 16
|
| 8 |
+
scale_by_sigma: True
|
| 9 |
+
dropout: 0.1
|
| 10 |
+
tie_word_embeddings: False
|
configs/model/large.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: large
|
| 2 |
+
type: ddit
|
| 3 |
+
hidden_size: 1280
|
| 4 |
+
cond_dim: 128
|
| 5 |
+
length: 1024
|
| 6 |
+
base_n_blocks: 28
|
| 7 |
+
# We try to roughly match parameter count
|
| 8 |
+
n_blocks: ${adjust_n_blocks:}
|
| 9 |
+
n_heads: 20
|
| 10 |
+
scale_by_sigma: True
|
| 11 |
+
dropout: 0.1
|
| 12 |
+
tie_word_embeddings: False
|
| 13 |
+
|
| 14 |
+
# 36 1280 20
|
configs/model/medium.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: medium
|
| 2 |
+
type: ddit
|
| 3 |
+
hidden_size: 1024
|
| 4 |
+
cond_dim: 128
|
| 5 |
+
length: 1024
|
| 6 |
+
base_n_blocks: 24
|
| 7 |
+
# We try to roughly match parameter count
|
| 8 |
+
n_blocks: ${adjust_n_blocks:}
|
| 9 |
+
n_heads: 16
|
| 10 |
+
scale_by_sigma: True
|
| 11 |
+
dropout: 0.1
|
| 12 |
+
tie_word_embeddings: False
|
configs/model/small-ar.yaml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: small
|
| 2 |
+
type: ddit
|
| 3 |
+
hidden_size: 768
|
| 4 |
+
cond_dim: 128
|
| 5 |
+
length: 1024
|
| 6 |
+
n_blocks: 12
|
| 7 |
+
n_heads: 12
|
| 8 |
+
scale_by_sigma: True
|
| 9 |
+
dropout: 0.1
|
| 10 |
+
causal: True
|
| 11 |
+
tie_word_embeddings: False
|