# Pretraining a Mixtral
---
[AIGen](https://github.com/Vectorrent/aigen) is a text generation and training library, originally forked from [AITextGen](https://aitextgen.minimaxir.com/) (which is now defunct).

AIGen is also the foundation of [VTX](https://github.com/0-5788719150923125/vtx).

To use this notebook with Kaggle, one must first enable the "Internet" feature. To do so:

1. Find "Notebook options" in the sidebar on the right-hand side of this page.
2. If required, verify your phone number.
3. Choose "Internet on".
4. Connect to the P100 accelerator.
5. Set up file persistence.

Do not forget to connect to an accelerator. A P100 is better than a single T4. However, with two T4s available, training may benefit from DistributedDataParallel (DDP), although the training cell below is configured for a single device (`devices=[0]`).

## Install dependencies

In [None]:
# Install AIGen straight from its GitHub repository (IPython `!` shell escape).
!pip install 'git+https://github.com/Vectorrent/aigen.git'

# Optional: install flash-attn to speed up attention.
# NOTE(review): using it presumably also requires changing attn_implementation
# (set to "eager" in the configuration cell), and flash-attn may not support
# the P100/T4 GPUs mentioned above — verify before enabling.
# !pip install -U flash-attn --no-build-isolation

## Configuration

Here we set a few run-wide variables. Most other settings (model architecture, datasets, and training hyperparameters) are hardcoded in the cells below for clarity.

In [None]:
# --- Run-wide settings reused by the cells below ---------------------------

# Name of this run: used as the cache/output sub-directory and the config name.
focus = 'mixtress'
# Numeric precision handed to aigen(precision=...).
precision = 32
# Attention backend handed to aigen(attn_implementation=...).
attn_implementation = "eager"

# Root directory for caches, checkpoints and TensorBoard logs.
output_dir = "/kaggle/working"

# Mixtral itself is gated, so borrow the architecture (and tokenizer) from a public repo.
base_model = 'TitanML/tiny-mixtral'
tokenizer_model = base_model

# Keyword arguments forwarded verbatim to AutoTokenizer.from_pretrained().
tokenizer_config = {
    "cache_dir": f"{output_dir}/{focus}",
    "padding": "max_length",
    "padding_side": "left",
    "use_fast": True,
    "return_overflowing_tokens": True,
    "truncation": True,
    "trust_remote_code": True,
}

# True: resume from the run folder f"{output_dir}/{focus}".
# False: start a fresh run (the model-loading cell then deletes output_dir).
resume_training = False

## Pretraining

We want to train a model from scratch, so we load a known-good config from an existing model, then override its architecture settings.

In [None]:
from transformers import (
    AutoConfig,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizerFast,
)

# Load the tokenizer first: its special-token ids are copied into the config below.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_model, **tokenizer_config)

# Start from the base model's config, then reshape it into a small model to pretrain.
pretrain_config = AutoConfig.from_pretrained(base_model)

# Give the new model its own name and the tokenizer's BOS/EOS ids.
pretrain_config._name_or_path = focus
pretrain_config.bos_token_id = tokenizer.bos_token_id
pretrain_config.eos_token_id = tokenizer.eos_token_id

# Architecture/training overrides applied on top of the loaded config.
overrides = {
    # NOTE(review): this stores the repo id, not an architecture name such as
    # "mixtral"; confirm whether anything (e.g. aigen) reads config.model_type.
    "model_type": base_model,
    # Not standard Mixtral config fields; presumably consumed by aigen — TODO confirm.
    "universal": True,
    "world_size": 23,
    # Transformer body: 512-wide hidden state, 16 query heads sharing 8 key/value heads.
    "hidden_act": 'mish',
    "hidden_size": 512,
    "intermediate_size": 1024,
    "initializer_range": 0.02,
    "num_hidden_layers": 8,
    "num_attention_heads": 16,
    "num_key_value_heads": 8,
    "rope_theta": 1000000.0,
    # Mixture-of-experts: 9 experts, 3 used per token (per MixtralConfig semantics).
    "num_experts_per_tok": 3,
    "num_local_experts": 9,
    # Hard-coded; must match the tokenizer's vocabulary — TODO confirm if the tokenizer changes.
    "vocab_size": 32000,
    "tie_word_embeddings": True,
    "router_aux_loss_coef": 0.001,
    "router_jitter_noise": 0.1,
    "sliding_window": 4096,
    "attention_dropout": 0.1
}
for field_name, field_value in overrides.items():
    setattr(pretrain_config, field_name, field_value)

print("modified pretrain config:")
print(pretrain_config)

## Load the model

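Here we initialize the model with random weights. When `resume_training` is `True`, the model is instead reloaded from the previous run's output folder.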

In [None]:
# Instantiate the aigen wrapper, either for a fresh run or resuming from disk.
import os
import shutil
from aigen import aigen

prototype = None

if not resume_training:
    # Fresh run: hand aigen the base model id together with the modified pretrain_config.
    model, model_folder = base_model, None
    # WARNING: this deletes the entire output_dir (caches, checkpoints, logs),
    # not just this run's sub-directory.
    shutil.rmtree(output_dir, ignore_errors=True)
else:
    # Resume: point aigen at the previous run's folder. No model id or config
    # override is passed; presumably aigen reloads both from that folder.
    model, model_folder = None, f"{output_dir}/{focus}"
    pretrain_config = None

prototype = aigen(
    model=model,
    model_folder=model_folder,
    tokenizer=tokenizer,
    cache_dir=f"{output_dir}/{focus}",
    precision=precision,
    config=pretrain_config,
    device_map="cuda:0",
    attn_implementation=attn_implementation,
)

print(prototype)

## Metrics

We want to log training metrics, so we install TensorBoard and expose it via ngrok. This requires an authtoken from ngrok.com, saved as `NGROK_SECRET` in Kaggle's "Add-ons > Secrets".

In [None]:
from kaggle_secrets import UserSecretsClient
secret_label = "NGROK_SECRET"
secret_value = UserSecretsClient().get_secret(secret_label)

import os
import shutil

directory = f"{output_dir}/logs"
os.makedirs(directory, exist_ok=True)

if not resume_training:
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        shutil.rmtree(file_path)

if secret_value:

    !pip install ngrok tensorboard

    import threading
    import subprocess

    def start_tensorboard():
        subprocess.Popen(
            ["tensorboard", "--logdir", "/kaggle/working/logs", "--bind_all", "--samples_per_plugin", "scalars=999999999"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.STDOUT
        )

    tensorboard_thread = threading.Thread(target=start_tensorboard)
    tensorboard_thread.start()

    import ngrok

    listener = await ngrok.forward(6006, authtoken=secret_value)
    
    import time

    time.sleep(1)
    print(listener.url())

## Training

Finally, we train the model on a mixture of datasets streamed from the Hugging Face Hub: https://huggingface.co/datasets

In [None]:
# Train the model
#
# Streams a mixture of Hugging Face datasets into prototype.train() and logs
# metrics to TensorBoard under f"{output_dir}/logs/{focus}".

import os
from lightning.pytorch import loggers

# Create this run's log directory and a TensorBoard logger that writes into it.
os.makedirs(f"{output_dir}/logs/{focus}", exist_ok=True)
logger = loggers.TensorBoardLogger(f"{output_dir}/logs", name=focus, default_hp_metric=True)

prototype.train(
    devices=[0],  # only GPU 0 is used
    strategy="auto",
    # One dict per streamed dataset. NOTE(review): the semantics of "schemas",
    # "patterns", "delimiter", "buffer_size", "sample_rate", "val_split" and
    # "val_samples" are defined by aigen, not here. Presumably each schema maps a
    # dataset column to the prefix placed before its text, multiple schemas are
    # alternative formats, and sample_rate weights how often a source is used —
    # confirm against aigen before tuning.
    streaming_data=[
        {
            "hf": True,
            "repo": "allenai/c4", 
            "split": "train",
            "val_split": "validation",
            "subset": "en.noblocklist",
            "schemas": [
                {
                    "text": ""
                }
            ],
            "buffer_size": 1000,
            "val_samples": 1000,
            "sample_rate": 1.0
        },
        # No "val_split" here, so this source presumably supplies training data only.
        {
            "hf": True,
            "repo": "HuggingFaceFW/fineweb-edu", 
            "split": "train", 
            "subset": "sample-10BT",
            "schemas": [
                {
                    "text": ""
                }
            ],
            "delimiter": "\n",
            "buffer_size": 1000,
            "sample_rate":1.0
        },
        # Disabled sources, kept commented out for easy re-enabling.
#         {
#             "hf": True,
#             "repo": "cerebras/SlimPajama-627B",
#             "split": "train",
#             "val_split": "validation",
#             "val_samples": 1000,
# #             "snapshots": [
# #                 "2023-14"
# #             ],
# #             "subset": "sample-10B",
# #             "languages": [
# #                 "en"
# #             ],
#             "schemas": [
#                 {
#                     "text": ""
#                 }
#             ],
#             "buffer_size": 1000,
#             "sample_rate": 1.0
#         },
#         {
#             "hf": True,
#             "repo": "togethercomputer/RedPajama-Data-V2",
#             "split": "train",
#             "snapshots": [
#                 "2023-14"
#             ],
#             "subset": "sample-10B",
#             "languages": [
#                 "en"
#             ],
#             "schemas": [
#                 {
#                     "raw_content": ""
#                 }
#             ],
#             "buffer_size": 1000,
#             "sample_rate": 1.0
#         },
        # The same three columns rendered in three alternative prompt formats;
        # "patterns" lists the placeholders used by the first format.
        {
            "hf": True,
            "repo": "Muennighoff/natural-instructions",
            "split": "train",
            "val_split": "test",
            "schemas": [
                {
                   "definition": "¶{context}:> ",
                   "inputs": '¶{human}:> ',
                   "targets": '¶{robot}:> '
                },
                {
                   "definition": "SYSTEM: ",
                   "inputs": 'USER: ',
                   "targets": 'ASSISTANT: '
                },
                {
                   "definition": "CONTEXT: ",
                   "inputs": 'INPUT: ',
                   "targets": 'OUTPUT: '
                }
            ],
            "patterns": [
                '{context}',
                '{human}',
                '{robot}'
            ],
            "delimiter": "\n",
            "buffer_size": 1000,
            "val_samples": 1000,
            "sample_rate": 0.25,
        },
        # Same three-format layout as above, for context/instruction/response columns.
        {
            "hf": True,
            "repo": "databricks/databricks-dolly-15k",
            "split": "train",
            "schemas": [
                {
                   "context": "¶{context}:> ",
                   "instruction": '¶{instruction}:> ',
                   "response": '¶{response}:> '
                },
                {
                   "context": "SYSTEM: ",
                   "instruction": 'USER: ',
                   "response": 'ASSISTANT: '
                },
                {
                   "context": "CONTEXT: ",
                   "instruction": 'INPUT: ',
                   "response": 'OUTPUT: '
                }
            ],
            "patterns": [
               '{context}',
               '{instruction}',
               '{response}'
            ],
            "delimiter": "\n",
            "buffer_size": 1000,
            "sample_rate": 0.25,
        },
        # Two-column (prompt/text) variant of the same multi-format layout.
        {
            "hf": True,
            "repo": "HuggingFaceTB/smollm-corpus",
            "split": "train",
            "subset": "cosmopedia-v2",
            "schemas": [
                {
                   "prompt": '¶{prompt}:> ',
                   "text": '¶{text}:> '
                },
                {
                   "prompt": 'USER: ',
                   "text": 'ASSISTANT: '
                },
                {
                   "prompt": 'INPUT: ',
                   "text": 'OUTPUT: '
                }
            ],
            "patterns": [
               '{prompt}',
               '{text}'
            ],
            "delimiter": "\n",
            "buffer_size": 1000,
            "sample_rate": 0.5,
        },
        # Single column ("markdown") with no prefix.
        {
            "hf": True,
            "repo": "open-phi/textbooks",
            "split": "train",
            "schemas": [
                {
                   "markdown": '',
                }
            ],
            "delimiter": "\n",
            "buffer_size": 1000,
            "sample_rate": 1.0,
        },
        # Single "text" column with four alternative prefixes (including none).
        {
            "hf": True,
            "repo": "roneneldan/TinyStories",
            "split": "train",
            "subset": "default",
            "schemas": [
                {
                    "text": '',
                },
                {
                    "text": ': ',
                },
                {
                    "text": ':> ',
                },
                {
                    "text": 'OUTPUT: ',
                },
            ],
            "delimiter": "\n",
            "buffer_size": 1000,
            "sample_rate": 0.25,
            "val_split": "validation",
            "val_samples": 1000,
        },
    ],
    batch_size=2,
    gradient_accumulation_steps=8,  # presumably an effective batch of 2 x 8 = 16 sequences per optimizer step
    block_size=2048,  # presumably the training sequence length in tokens — confirm in aigen
    num_steps=20000,
    val_interval=1000,
    warmup_steps=10,
    optimizer="Lion",
    learning_rate=0.0001,
    weight_decay=0.001,
    gradient_clip_val=1.0,
    scheduler="cosine",
    loggers=[logger],
    gradient_checkpointing=True,
    # NOTE(review): the units of the *_every intervals (steps?) are defined by aigen — confirm.
    generate_every=10,
    save_every=25,
    checkpoint_every=25,
    resume=resume_training,  # mirrors the resume_training flag from the configuration cell
    progress_bar=True,
    output_dir=f"{output_dir}/{focus}",
)

## Testing

For testing, we just run an interactive inference session.

In [None]:
# Test inference
#
# Interactive loop: read a prompt, generate a completion with `prototype`, print
# it, and repeat. The loop ends cleanly, instead of failing the notebook run, when
# input is unavailable or interrupted:
#   - a non-interactive run such as Kaggle's "Save & Run All", where IPython raises
#     StdinNotImplementedError (a NotImplementedError subclass);
#   - closed stdin (EOFError);
#   - a kernel interrupt while waiting for input (KeyboardInterrupt).

while True:
    print("PROMPT:\n")
    try:
        prompt = input()
    except (EOFError, KeyboardInterrupt, NotImplementedError):
        print("Input unavailable or interrupted; ending the inference session.")
        break
    completion = prototype.generate(
        prompt=prompt,
        do_sample=True,
        # NOTE(review): in Hugging Face generate(), min_length also counts the
        # prompt tokens; min_new_tokens may be what was intended — confirm.
        min_length=23,
        max_new_tokens=111,
        temperature=0.9,
        eta_cutoff=0.0003,
        # NOTE(review): penalty_alpha only selects contrastive search when
        # do_sample=False, so with sampling enabled it is likely ignored.
        penalty_alpha=0.6,
        top_k=4,
        repetition_penalty=1.1,
        no_repeat_ngram_size=13,
        renormalize_logits=True,
        remove_invalid_values=True,
        max_time=60,
        use_cache=True,
    )
    print("COMPLETION:\n")
    print(completion)