In [1]:
!pip install torchtune
!pip install torchao
!pip install wandb


import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import tqdm
from dataclasses import dataclass
from torchtune.modules import RMSNorm
from tokenizers import Tokenizer
from pathlib import Path
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import wandb
from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets



In [2]:
!nvidia-smi

Sun Apr 27 03:34:00 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   31C    P0             45W /  400W |       5MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
import wandb
from google.colab import userdata

wandb.login(key=userdata.get('WANDB'))

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mlaampt[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
import os

def setup(rank=None, world_size=None):
    # os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ['MASTER_PORT'] = '12355'
    init_process_group("nccl")

def cleanup():
    destroy_process_group()

In [5]:
from pathlib import Path
data_path = Path('data')
data_path.mkdir(exist_ok=True)
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
!cp input.txt data/input.txt

--2025-04-27 03:34:09--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-04-27 03:34:09 (167 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [6]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-27b-it", token=userdata.get('HF_TOKEN'))
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

1

In [7]:
tokenizer.vocab_size

262144

In [8]:
@dataclass
class ModelArgs:
    #Hyperparameters

    block_size = 256
    batch_size = 64
    embeddings_dims = 512
    attn_dropout = 0.1
    no_of_heads = 8 #IMP needs to be thoroughly calculated
    dropout = 0.1
    epochs = 100
    max_lr = 2.5e-4
    no_of_decoder_layers = 6 #IMP needs to be thoroughly calculated
    weight_decay_optim = 0.1
    beta_1 = 0.9
    beta_2 = 0.95
    device = 'cuda:0'
    no_kv_heads = 2
    scaling_factor = 0.5
    vocab_size = len(tokenizer.get_vocab()) + 768
    local_block_size = 128
    base_freq=10000

In [9]:
#Datasets

# Using tinyshakespeare

with open('data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [10]:
def save_checkpoint(model):
    ckp = model.module.state_dict()
    torch.save(ckp, "checkpoint.pt")
    print("Checkpoint saved")

In [11]:
#Subword level tokenization

#Loading custom trained BPE
# Load the tokenizer
# tokenizer = Tokenizer.from_file("data/bpe_tokenizer_tinyshakespeare_1k.json")
# vocab_size = tokenizer.get_vocab_size()
# Encode and decode functions
# encode = lambda s: tokenizer.encode(s).ids
# decode = lambda l: tokenizer.decode(l)





###############################################################################
#Character level tokenization

# # here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)


# create a mapping from characters to integers
stoi = { ch: i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string


# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - ModelArgs.block_size, (ModelArgs.batch_size,))
    x = torch.stack([data[i:i+ModelArgs.block_size] for i in ix])
    y = torch.stack([data[i+1:i+ModelArgs.block_size+1] for i in ix])
    x, y = x.to(ModelArgs.device), y.to(ModelArgs.device)
    return x, y

In [12]:
tinystories = True
fw = False
fw_train = None
fw_test = None
if(tinystories):

    fw_train = load_dataset("roneneldan/TinyStories", split="train")
    fw_test = load_dataset("roneneldan/TinyStories", split="validation")
    print(fw_train)
    print(fw_test)
if(fw):
    fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False)
    fw_train = fw_train.train_test_split(test_size=0.01)
    print(fw_train)
    print(fw_train)

README.md:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

(…)-00000-of-00004-2d5a1467fff1081b.parquet:   0%|          | 0.00/249M [00:00<?, ?B/s]

(…)-00001-of-00004-5852b56a2bd28fd9.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00002-of-00004-a26307300439e943.parquet:   0%|          | 0.00/246M [00:00<?, ?B/s]

(…)-00003-of-00004-d243063613e5a057.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00000-of-00001-869c898b519ad725.parquet:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

Dataset({
    features: ['text'],
    num_rows: 2119719
})
Dataset({
    features: ['text'],
    num_rows: 21990
})


In [13]:
def prepare_dataset(split, device, batch_size):
    print("Device is: ", device)

    def collate_fn(batch):
        # Extract text data
        texts = [item ["text"] for item in batch]

        input_encodings = tokenizer(texts, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")

        input_encodings["labels"] = input_encodings["input_ids"].clone()  # Use `input_ids` as labels

        input_encodings["labels"][:, :-1] = input_encodings["input_ids"][:, 1:]  # Shift right
        input_encodings["labels"][:, -1] = tokenizer.eos_token_id  # Let the last token be end

        return input_encodings


    dataloader = None
    if(tinystories):
        if(split == 'train'):
            data_loader = DataLoader(
            fw_train,
            # generator=generator,
            batch_size=batch_size,

            # sampler=DistributedSampler(fw_train, shuffle=True),
            collate_fn=collate_fn,
            drop_last=True,
            shuffle=False
        )
        elif(split == 'val'):
            data_loader = DataLoader(
            fw_test,


            batch_size=batch_size,
            # sampler=DistributedSampler(fw_test, shuffle=True),
            collate_fn=collate_fn,
            drop_last=True,
            shuffle=False
        )
    elif(fw):
        if(split == 'train'):
            data_loader = DataLoader(
            fw_train['train'],
            batch_size=batch_size,


            sampler=DistributedSampler(fw_train['train'], shuffle=True),
            collate_fn=collate_fn,
            drop_last=True,
            shuffle=False
    )
        elif(split == 'val'):
            data_loader = DataLoader(
            fw_train['test'],
            batch_size=batch_size,
                # generator=generator,
            sampler=DistributedSampler(fw_train["test"]),
            collate_fn=collate_fn,

            drop_last=True,
            shuffle=False
        )
    return data_loader

In [14]:
# from andrej karapathy github
def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0):
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated_tokens = []
    ModelArgs.inference=True
    for _ in range(max_length):
        with torch.no_grad(), torch.autocast(device_type=ModelArgs.device, dtype=torch.bfloat16):
            outputs = model(input_ids)
            logits = outputs[:, -1, :]

            probs = F.softmax(logits, dim=-1)

            # Top-k filtering
            top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)


            # Apply temperature scaling
            probs = probs / temperature

            # Sample from top-k
            next_token = torch.multinomial(top_k_probs, num_samples=1)


            # generated_tokens.append(next_token.item())

            xcol = torch.gather(top_k_indices, -1, next_token)
            # generated_tokens.append(xcol)
            input_ids = torch.cat([input_ids, xcol], dim=1) #1 because is it the dimension of the sequence

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

In [15]:
class Normalization(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        self.rmsnorm_layer = RMSNorm(dim=embeddings_dims)


    def forward(self, x):

        x = self.rmsnorm_layer(x)
        return x

In [16]:
# import numpy as np
class RotaryEmbeddings(nn.Module):
    def __init__(
        self,
         device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        batch_size: int = ModelArgs.batch_size,
        scaling_factor: float = 0.5,
    ):
        super().__init__()

        self.embeddings_dims = embeddings_dims
        self.block_size = block_size
        self.batch_size = batch_size
        self.scaling_factor = scaling_factor
        self.theta = 0
        self.device=device

    def apply_rope(self, seq, base_freq):
        batch_size, seq_len, embeds_dims = seq.shape
        token_indices = torch.arange(0 , seq_len, dtype=torch.float32,  device = self.device).unsqueeze(1)
        positions = torch.arange(0 , self.embeddings_dims, 2, dtype=torch.float32,  device = self.device).unsqueeze(0)
        theta = base_freq ** (-2 * (positions * self.scaling_factor) / self.embeddings_dims) #Position Interpolation
        angles = token_indices * theta
        angles = angles.expand(seq_len, -1) # because this thing needs to be applied to every sequence in the batch but with embeds dims halved
        x_reshaped = seq.view(batch_size, seq_len, self.embeddings_dims // 2, 2)

        cos_angles = torch.cos(angles)
        sin_angles = torch.sin(angles)


        out = torch.stack([x_reshaped[..., 0]*cos_angles - (x_reshaped[...,1] * sin_angles), x_reshaped[...,1] * cos_angles + x_reshaped[..., 0] * sin_angles], dim=1)
        out = out.view(batch_size, seq_len, embeds_dims)
        return out

    def forward(self, x, base_freq):

        res = self.apply_rope(x,base_freq=base_freq)
        return res

In [17]:
class MQA(nn.Module):
    def __init__(
        self,
        device,
        no_of_q_heads: int,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,


    ):
        super().__init__()


        # self.no_of_q_heads = no_of_heads // no_of_kv_heads
        # self.no_of_q_heads = no_of_q_heads
        self.no_of_kv_heads = 2 # I want to have a kv for each pair of query heads
        self.head_size = embeddings_dims // no_of_q_heads
        # self.kv_head_size = (embeddings_dims // self.no_of_kv_heads) * 2
        self.rotary= RotaryEmbeddings(embeddings_dims=self.head_size,  device = device)
        # self.rotary_k = RotaryEmbeddings(embeddings_dims=self.kv_head_size,  device = device)
        # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size,  bias=False)
        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size,  dtype=torch.float32, bias=False,  device = device)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size,  dtype=torch.float32, bias=False,  device = device)
        self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=self.head_size * self.no_of_kv_heads, out_features=embeddings_dims,  dtype=torch.float32, bias=False,  device = device)
        self.device = device
        self.multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=self.head_size,  bias=False,  device = self.device) for _ in range(self.no_of_kv_heads)])

    def scaled_dot_product(self, q, k, v, block_size, base_freq):

            # masked = torch.tril(torch.ones((block_size, block_size),  requires_grad=False,  device = self.device))
            normalized_q = q * (torch.norm(q, p=2)** -1)
            q = self.rotary(normalized_q, base_freq)
            masked_table = torch.tril(torch.ones((block_size, block_size),  requires_grad=False,  device = self.device))
            # rotary_query = matrix @ q.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
            # rotary_key = matrix @ k.permute(1,2,0)  #  (B,T, C,C  ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
            # print("Query: ", q.shape)
            # print("Keys: ", k.shape)
            # print(q.permute(2,0,1).shape)
            # print(k.permute(2,0,1).transpose(-2, -1).shape)
            # weights = q.permute(2,0,1) @ k.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
            # weights = q @ k.permute(2,1,0)
            # print(weights.shape)
            # print(masked.shape)
            weights = q @ torch.transpose(k, dim0=-2, dim1=-1) * (k.shape[-1] ** -0.5)
            masked_values = weights.masked_fill(masked_table[: block_size, : block_size] == 0, float('-inf'))
            weights_normalized = nn.functional.softmax(masked_values, dim=-1) #Normalize along the embeddings dimension for all the tokens
            weights_normalized = self.dropout(weights_normalized)
            out = weights_normalized @ v
            return out

    def forward(self,x, base_freq=10000):
        # print("MQA: ", x.shape)
        batch, block_size, embeddings_dims = x.shape

        # query = self.query(x)
        # matrix = self.rotary_matrix(block_size)


        key = self.key(x)
        key_normalized = key * (torch.norm(key, p=2)** -1)
        values = self.value(x)
        # print("Keys: ", key.shape)
        # print("Values: ", values.shape)
        # rotary_value = self.rotary(values)
        rotary_key = self.rotary(key_normalized, base_freq)
        multi_query_concat = torch.cat([self.scaled_dot_product(query(x), rotary_key, values, block_size, base_freq) for query in self.multi_query], dim=-1)
        # print("Multi query: ", multi_query_concat.shape)

        linear_layer= self.linear_layer(multi_query_concat)
        # out = self.dropout(linear_layer)
        return linear_layer

In [18]:
class GQA(nn.Module):
    def __init__(
        self,
         device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        # no_of_q_heads: int = ModelArgs.no_of_heads,
        mqa_heads: int = ModelArgs.no_kv_heads
    ):
        super().__init__()

        # self.no_of_kv_heads = no_of_kv_heads
        self.no_of_q_heads = ModelArgs.no_of_heads // mqa_heads
        # self.head_dim = embeddings_dims // self.no_kv_heads
        self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_q_heads, out_features=embeddings_dims , dtype=torch.float32,  bias=False,  device = device)
        self.device = device
        self.mqa = nn.ModuleList([MQA(no_of_q_heads=self.no_of_q_heads, embeddings_dims=embeddings_dims, device = self.device, block_size=block_size) for _ in range(self.no_of_q_heads)])
        # self.mqa = MQA(no_of_q_heads=self.no_of_q_heads, device=self.device, embeddings_dims=embeddings_dims, block_size=block_size)
    def forward(self,x, base_freq):

        batch, block_size, embeddings_dims = x.shape

        # res = self.mqa(x)
        grouped_query_concat = torch.cat([group(x, base_freq) for group in self.mqa], dim=-1)

        linear_layer= self.linear_layer(grouped_query_concat) #Basically MQA is made into GQA with no_of_q_heads and this class right here is just to consolidate everything into one
        out = self.dropout(linear_layer)
        return out

In [19]:
class Swish(nn.Module):
    def __init__(
        self,
        device,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()

        self.sig = torch.nn.Sigmoid()


    def forward(self, x):
        swish = x * self.sig(x)

        return swish

In [20]:
class SWiGLU(nn.Module):
    def __init__(
        self,
        device,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        self.hidden_dims = int(2 * ( 4 * embeddings_dims) / 3)
        self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
        self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims,  bias=False, dtype=torch.float32,  device = device)
        self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims,  bias=False, dtype=torch.float32,  device = device)
        self.linear_layer3 = nn.Linear(in_features=self.hidden_dims, out_features=embeddings_dims,  bias=False, dtype=torch.float32,  device = device)




    def forward(self, x):
        swish_res = self.swish(self.linear_layer1(x))
        x_V = self.linear_layer2(x)
        res = torch.mul(swish_res, x_V)
        out = self.linear_layer3(res)
        return out

In [21]:
class FFN(nn.Module):
    def __init__(self,
                  device,
                  embeddings_dims: int = ModelArgs.embeddings_dims,
                  block_size: int = ModelArgs.block_size,
                  vocab_size: int = ModelArgs.vocab_size,
                   dropout = ModelArgs.dropout

                 ):
        super().__init__()

        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims,  dtype=torch.float32,  device = device)
        self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims,  device = device)
        # self.dropout = nn.Dropout(p = dropout)
    def forward(self, x):

        x = self.swiglue(x)
        x = self.linear_layer(x)
        # x = self.dropout(x)
        return x

In [22]:
class DecoderLayer(nn.Module):
    def __init__(self,
                device,
                embeddings_dims: int = ModelArgs.embeddings_dims,
                dropout = ModelArgs.dropout,
                block_size: int = ModelArgs.block_size,
                vocab_size: int = ModelArgs.vocab_size,

                 ) :
        super().__init__()

        # self.base_freq = ModelArgs.base_freq
        self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size,  device = device)
        self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, mqa_heads=2,  device = device)
        # self.norm = Normalization(embeddings_dims=embeddings_dims)
        self.norm1 = Normalization(embeddings_dims=embeddings_dims)
        self.norm2 = Normalization(embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p = dropout)
    def forward(self, x, base_freq):

        x = x + self.gqa(self.norm1(x), base_freq)
        x = x + self.feedforward_network(self.norm2(x))
        return x

In [23]:
class Gemma(nn.Module):
    def __init__(self,
                    device,
                  embeddings_dims: int = ModelArgs.embeddings_dims,
                  no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
                  block_size: int = ModelArgs.block_size,
                  vocab_size: int = ModelArgs.vocab_size,
                  dropout = ModelArgs.dropout

                 ) :
        super().__init__()
        self.base_freq = ModelArgs.base_freq
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims,  dtype=torch.float32,  device = device)
        self.decoder = nn.ModuleList(DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout,  device = device) for _ in range(no_of_decoder_layers))
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size,  dtype=torch.float32,  device = device)
        self.dropout = nn.Dropout(p = dropout)
        self.norm = Normalization(embeddings_dims)


        #weight tying
        # self.embeddings.weight = self.linear_layer.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)



    def forward(self, x):
        global_base_freq = 100000
        local_base_freq = 10000
        index = 0
        no_of_layers = 0
        x = self.embeddings(x)
        x = self.dropout(x)
        temp = x.clone()
        # x = self.decoder(x)
        for layer in self.decoder:
            if no_of_layers % 5 == 0:
                x = layer(x, global_base_freq)
                # print("x shape: ", x.shape)
            else:

                local_block = temp[:, : index + ModelArgs.local_block_size, :]
                x = layer(local_block, local_base_freq)
                index += ModelArgs.local_block_size
                # print("x shape local: ", x.shape)
            no_of_layers += 1
        # print(x.shape)
        x = self.norm(x)
        x = self.linear_layer(x)

        return x

In [24]:
# Instantiating the model
# device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
# ModelArgs.device = device
model = Gemma(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=ModelArgs.device)
model = model.to(ModelArgs.device)

# model = DDP(model, device_ids=[gpu_ids])

In [25]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [26]:
#Printing a summary of the architecture
from torchinfo import summary
idx, targets = get_batch('test')
idx = idx.to(ModelArgs.device)
summary(model=model,
        input_data=idx,
        # input_size=(ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
Gemma (Gemma)                                                [64, 256]            [64, 256, 262914]    --                   True
├─Embedding (embeddings)                                     [64, 256]            [64, 256, 512]       134,611,968          True
├─Dropout (dropout)                                          [64, 256, 512]       [64, 256, 512]       --                   --
├─ModuleList (decoder)                                       --                   --                   --                   True
│    └─DecoderLayer (0)                                      [64, 256, 512]       [64, 256, 512]       --                   True
│    │    └─Normalization (norm1)                            [64, 256, 512]       [64, 256, 512]       512                  True
│    │    └─GQA (gqa)                                        [64, 256, 512]       [64, 256, 51

In [27]:
# import tqdm
def train():
    # Set device to CUDA if available
    device = ModelArgs.device
    print(f"Start running training on {device}.")

    # Initialize wandb for experiment tracking
    wandb.init(
        project = 'Gemma-Training',
        # config = ModelArgs, # you can uncomment this to log model config
    )

    # Create model and move to GPU
    model = Gemma(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size,
                  vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=device)
    model = model.to(device)

    print("Model loaded")
    # Setup optimizer
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=ModelArgs.max_lr)

    # Training parameters
    save_checkpoint_iter = 500
    total_iters = 25000
    eval_iters = 500


    # Training progress bar
    train_epoch_iterator = tqdm.tqdm(range(total_iters), desc="Training")
    val_dataloader = prepare_dataset('val', device, ModelArgs.batch_size)
    val_iterator = iter(val_dataloader)
    # Get batches for training
    @torch.inference_mode()
    def estimate_loss():
        out = {}
        model.eval()
        count = 0
        for split in ['val']:
            print(f"Starting with {split} evaluation...")
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):

                nonlocal val_iterator

                # for k, batch in enumerate(dataloader):
                try:
                    batch = next(val_iterator)
                except StopIteration:
                    val_iterator = iter(val_dataloader)
                    batch = next(val_iterator)

                input_ids = batch["input_ids"].to(device)
                targets = batch["labels"].to(device)

                logits = model(input_ids)
                batch_size, block_size, embeddings_dims = logits.shape
                logits = logits.view(batch_size*block_size, embeddings_dims)
                targets = targets.view(batch_size * block_size)
                loss = nn.functional.cross_entropy(logits, targets)
                losses[k] = loss.item()
                # count += 1
            out[split] = losses.mean()

        model.train()
        return out
    token_count = 0
    # Start training loop
    model.train()
    print("Lessgoo...")
    dataloader = prepare_dataset('train', device, ModelArgs.batch_size)
    train_dataloader = iter(dataloader)
    accumulated_loss = 0.0
    for step in train_epoch_iterator:
        # Periodically evaluate loss on train and val sets
        if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
            losses = estimate_loss()
            avg_val_loss = torch.Tensor([losses['val']]).to(device)
            print(f"step {step}: train loss {accumulated_loss:.4f}, val loss {losses['val']:.4f}")
            val_perplexity = torch.exp(torch.tensor(avg_val_loss)).item()
            # Log metrics to wandb
            wandb.log({
                "val_perplexity": val_perplexity,
                # "val_step_loss": losses['train'],
                "val_step_loss": losses['val'],
                "step": step
            })

        # Save checkpoint periodically
        if step % save_checkpoint_iter == 0 and step != 0:
            print(f"Saving the model checkpoint for step: {step}")
            torch.save(model.state_dict(), "checkpoint.pt")
            print("Checkpoint saved")

        # Get batch for training step
        try:
            batch = next(train_dataloader)
        except StopIteration:
            train_dataloader = iter(dataloader)
            batch = next(train_dataloader)

        # for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        targets = batch["labels"].to(device)

        # Forward pass
        logits = model(input_ids)
        batch_size, block_size, embeddings_dims = logits.shape
        logits = logits.view(batch_size*block_size, embeddings_dims)
        targets = targets.view(batch_size * block_size)
        loss = nn.functional.cross_entropy(logits, targets)

        token_count += (len(input_ids) * ModelArgs.batch_size)

        # Backward pass
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        accumulated_loss = loss.item()
        perplexity = torch.exp(torch.tensor(accumulated_loss)).item()  # Calculate perplexity
        # if(device == 0):
        wandb.log({
                    # "Learning Rate": scheduler.get_last_lr()[0],
                    "Train_Loss": accumulated_loss,
                    # "Train loss": loss.item(),
                    "Train Perplexity": perplexity,
                    "Total Tokens Processed": token_count,
                    "Step": step,
                    # "Gradient Norm": total_norm_before.item(),
                    # "Epoch": epoch

        })

        if(step % eval_iters == 0):
                prompt = "Once upon a time "
                generated_text = topk_sampling(model, prompt, max_length=ModelArgs.block_size, top_k=50, temperature=1.0, device=device)


                print(f" Step: {step} | Generated Text: {generated_text}")

    # Finish wandb run
    wandb.finish()

# Print CUDA device count but won't be using DDP
world_size = torch.cuda.device_count()
print(f"CUDA devices available: {world_size}")
train()

CUDA devices available: 1
Start running training on cuda:0.


Model loaded


Training:   0%|          | 0/25000 [00:00<?, ?it/s]

Device is:  cuda:0
Lessgoo...
Device is:  cuda:0


Training:   0%|          | 0/25000 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 16.05 GiB. GPU 0 has a total capacity of 39.56 GiB of which 5.13 GiB is free. Process 26277 has 34.42 GiB memory in use. Of the allocated memory 21.05 GiB is allocated by PyTorch, and 12.88 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)