
Self-Distillation Through Time (SDTT)

SDTT is a distillation method for diffusion language models. Recent diffusion language models such as SEDD and MDLM achieve strong results, but sampling from them is slow: their architecture is non-causal, so they cannot use KV-caching. We therefore devise a novel distillation method that reduces the inference latency of discrete diffusion models. After distillation, we can sample up to 8x faster than GPT-2, even though GPT-2 uses KV-caching. Find more details below and in our GitHub repository.
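The latency argument can be made concrete with a rough, hypothetical cost model (an illustration, not a measurement from the paper). With KV-caching, an autoregressive model needs one sequential forward call per generated token, each processing only the newest token. A non-causal diffusion model needs one forward call per denoising step, each over the whole sequence. Distillation therefore lowers latency by reducing the number of denoising steps. The function names below are hypothetical.

```python
def sequential_calls_autoregressive(seq_len: int) -> int:
    """Sequential forward calls for a causal LM with KV-caching.

    Each call processes only the newest token; keys/values of earlier
    tokens are read from the cache, so one call is needed per generated
    token (the prompt is ignored in this rough model).
    """
    return seq_len


def sequential_calls_diffusion(num_steps: int) -> int:
    """Sequential forward calls for a non-causal diffusion LM.

    Every denoising step re-encodes the full sequence with bidirectional
    attention, so keys/values cached from a previous step would be stale.
    One call is needed per step, independent of the sequence length.
    """
    return num_steps


def tokens_encoded_diffusion(seq_len: int, num_steps: int) -> int:
    # Total token positions pushed through the network across all steps.
    return seq_len * num_steps
```

For `seq_len=1024`, the autoregressive model needs 1024 sequential calls, while a diffusion sampler with `num_steps=256` needs 256, each encoding all 1024 positions. Reducing `num_steps` shrinks both the number of sequential calls and the total work; whether this beats GPT-2 in wall-clock time also depends on hardware and batch size.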

Using SDTT

  • We released 3 groups of models:
    1. The baseline students, distilled with the kld, mse and tvd objectives from a teacher trained for 1M steps.
    2. The students from the scaling experiments, with sizes sm, md and large, distilled from teachers trained for 400k steps.
    3. The teachers from the scaling experiments, with sizes sm, md and large, before any distillation.
  • To load those models, first install our code:
git clone https://github.com/jdeschena/sdtt.git
cd sdtt
pip install -r requirements.txt
pip install flash-attn
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -e .
  • You can then import our models, sample from them and evaluate them:

Load the baseline students

from sdtt import load_small_student
student = load_small_student(loss="kld", round=7)  # load the kld student after the last distillation round
student = load_small_student(loss="mse", round=2)  # load the mse student after the second distillation round
student = load_small_student(loss="tvd", round=1)  # load the tvd student after the first distillation round

Load the students from the scaling experiment

from sdtt import load_scaling_student
student = load_scaling_student(size="sm", round=7)  # load small student after the last distillation round
student = load_scaling_student(size="md", round=1)   # load medium student after the first distillation round
student = load_scaling_student(size="large", round=3)  # load large student after the third distillation round

Load the teachers from the scaling experiment

from sdtt import load_scaling_teacher
teacher = load_scaling_teacher(size="sm")  # load small teacher
teacher = load_scaling_teacher(size="md")  # load medium teacher
teacher = load_scaling_teacher(size="large")  # load large teacher

Sample from the pretrained models

from sdtt import load_small_student, load_scaling_student, load_scaling_teacher
import torch

model = load_small_student(loss="kld", round=7)  # load model, see above
model.cuda()  # put model on gpu

# Unconditional generation
tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
)
# Detokenize
uncond_text = model.tokenizer.batch_decode(tokens)

# Conditional generation, based on a prompt
# Prepare a prompt
prompt = "Today is a great day. The sun is shining,"
prompt_tokens = model.tokenizer(prompt)["input_ids"]
prompt_tokens.insert(0, model.tokenizer.bos_token_id)
prompt_tokens = torch.tensor(prompt_tokens, device="cuda")
prompt_len = len(prompt_tokens)

def project_fn(x):
    # Overwrite the first prompt_len tokens of every sample with the prompt
    x[:, :prompt_len] = prompt_tokens
    return x  # Don't forget to return

tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
    project_fn=project_fn
)

cond_text = model.tokenizer.batch_decode(tokens)
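The `project_fn` argument implements prompting by overwriting the first `prompt_len` positions with the prompt. The toy loop below illustrates why this keeps the prompt intact, assuming the sampler applies `project_fn` to the partially denoised sequence at every step. The sampler here is a hypothetical, simplified masked-diffusion stand-in that fills masked positions with random tokens; it is not the SDTT sampler, and in the real model the predictions would be conditioned on the pinned prefix.

```python
import math
import random
from typing import Callable, List

MASK = -1  # hypothetical mask-token id, outside the toy vocabulary


def toy_masked_diffusion_sample(
    seq_len: int,
    num_steps: int,
    vocab_size: int,
    project_fn: Callable[[List[int]], List[int]],
    seed: int = 0,
) -> List[int]:
    rng = random.Random(seed)
    x = project_fn([MASK] * seq_len)  # start fully masked, then pin the prompt
    for step in range(num_steps):
        masked = [i for i, tok in enumerate(x) if tok == MASK]
        # Unmask an equal share of the remaining masked positions per step,
        # so that every position is filled by the final step.
        n_unmask = math.ceil(len(masked) / (num_steps - step))
        for i in rng.sample(masked, n_unmask):
            x[i] = rng.randrange(vocab_size)  # stand-in for the model's prediction
        x = project_fn(x)  # re-apply the projection after every step
    return x


prompt = [11, 42, 7]  # hypothetical prompt token ids


def project_fn(x: List[int]) -> List[int]:
    x[: len(prompt)] = prompt  # pin the prompt positions
    return x


tokens = toy_masked_diffusion_sample(
    seq_len=16, num_steps=4, vocab_size=100, project_fn=project_fn
)
```

Because the projection runs after every step, the sampler can never overwrite the prompt, and only the remaining positions are generated.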

For more details, please see our GitHub repository: SDTT (https://github.com/jdeschena/sdtt).

Model Details

Our small checkpoints are distilled from the MDLM checkpoints. We also release medium (424M) and large (863M) checkpoints that we pretrained ourselves.
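The baseline students were distilled with the kld, mse and tvd objectives, which compare the student's per-token predictions with the teacher's. The sketch below computes the three divergences for a single token position, assuming both models output a categorical distribution over the vocabulary. The KL direction (teacher to student), applying MSE to probabilities rather than logits, and the function names are assumptions for illustration; the actual losses are in the GitHub repository.

```python
import math
from typing import Sequence


def kld(p_teacher: Sequence[float], p_student: Sequence[float], eps: float = 1e-12) -> float:
    # KL(teacher || student) for one position. Zero-probability teacher
    # entries contribute nothing; eps guards against dividing by zero.
    return sum(
        p * math.log(p / max(q, eps))
        for p, q in zip(p_teacher, p_student)
        if p > 0
    )


def mse(p_teacher: Sequence[float], p_student: Sequence[float]) -> float:
    # Mean squared error between the two probability vectors.
    return sum((p - q) ** 2 for p, q in zip(p_teacher, p_student)) / len(p_teacher)


def tvd(p_teacher: Sequence[float], p_student: Sequence[float]) -> float:
    # Total variation distance: half the L1 distance, always in [0, 1].
    return 0.5 * sum(abs(p - q) for p, q in zip(p_teacher, p_student))
```

All three are zero when the student matches the teacher exactly. During training, such per-position values would typically be averaged over the positions the student predicts; that reduction is also an assumption here.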

Citation

Please cite our work using the BibTeX entry below:

BibTeX:

@article{deschenaux2024autoregressionfastllmsselfdistillation,
        title={Beyond Autoregression: Fast LLMs via Self-Distillation Through Time},
        author={Deschenaux, Justin and Gulcehre, Caglar},
        eprint={2410.21035},
        archivePrefix={arXiv},
        primaryClass={cs.LG},
        url={https://arxiv.org/abs/2410.21035}, 
}

Contact

Justin Deschenaux ([email protected])
