torinriley committed
Commit 3c8aa4a · 1 Parent(s): 8468281
Files changed (3)
  1. data/streaming_dataset.py +63 -0
  2. model.py +22 -2
  3. train.py +43 -17
data/streaming_dataset.py ADDED
@@ -0,0 +1,63 @@
+ import os
+ import numpy as np
+ import tiktoken
+ from datasets import load_dataset, concatenate_datasets, interleave_datasets
+ from torch.utils.data import IterableDataset
+ import torch
+
+ class StreamingDataset(IterableDataset):
+     """Streaming dataset that loads and processes data on the fly"""
+
+     def __init__(self, dataset_configs, block_size=2048, batch_size=12):
+         self.dataset_configs = dataset_configs
+         self.block_size = block_size
+         self.batch_size = batch_size
+         self.enc = tiktoken.get_encoding("gpt2")
+
+     def load_and_process_chunk(self, dataset_name, split="train"):
+         # Load datasets with appropriate configs
+         if dataset_name == "openwebtext":
+             dataset = load_dataset(dataset_name, split=split, streaming=True, trust_remote_code=True)
+         elif dataset_name == "the_pile":
+             dataset = load_dataset("the_pile", split=split, streaming=True)
+         elif dataset_name == "red_pajama":
+             dataset = load_dataset("togethercomputer/RedPajama-Data-1T", split=split, streaming=True)
+
+         for example in dataset:
+             ids = self.enc.encode_ordinary(example['text'])
+             ids.append(self.enc.eot_token)
+             if len(ids) >= self.block_size:
+                 # Return chunks of block_size
+                 for i in range(0, len(ids) - self.block_size + 1, self.block_size):
+                     yield torch.tensor(ids[i:i + self.block_size])
+
+     def __iter__(self):
+         # Interleave datasets with specified weights
+         iterators = []
+         weights = []
+         for config in self.dataset_configs:
+             iterators.append(self.load_and_process_chunk(config['name']))
+             weights.append(config['weight'])
+
+         # Normalize weights
+         weights = np.array(weights) / sum(weights)
+
+         while True:
+             # Randomly select a dataset based on weights
+             dataset_idx = np.random.choice(len(iterators), p=weights)
+             try:
+                 batch = []
+                 for _ in range(self.batch_size):
+                     batch.append(next(iterators[dataset_idx]))
+                 yield torch.stack(batch)
+             except StopIteration:
+                 # Restart iterator if it's exhausted
+                 iterators[dataset_idx] = self.load_and_process_chunk(self.dataset_configs[dataset_idx]['name'])
+                 continue
+
+ # Example usage:
+ dataset_configs = [
+     {'name': 'openwebtext', 'weight': 0.4},
+     {'name': 'the_pile', 'weight': 0.3},
+     {'name': 'red_pajama', 'weight': 0.3}
+ ]
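
A quick, hypothetical smoke test for the new loader (not part of the commit; it assumes the Hugging Face datasets above are reachable for streaming, and uses small, arbitrary block_size/batch_size values to keep the check cheap):

    # hypothetical check, not in this commit
    from data.streaming_dataset import StreamingDataset

    configs = [{'name': 'openwebtext', 'weight': 1.0}]  # single source keeps the check fast
    ds = StreamingDataset(configs, block_size=128, batch_size=2)
    batch = next(iter(ds))                              # one interleaved batch of GPT-2 token ids
    print(batch.shape, batch.dtype)                     # expected: torch.Size([2, 128]) torch.int64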
model.py CHANGED
@@ -114,6 +114,7 @@ class GPTConfig:
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
+     gradient_checkpointing: bool = False # Enable gradient checkpointing for memory efficiency

class GPT(nn.Module):

@@ -144,6 +145,9 @@ class GPT(nn.Module):
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

+         # Enable gradient checkpointing if configured
+         self.gradient_checkpointing = getattr(config, 'gradient_checkpointing', False)
+
        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

@@ -177,8 +181,24 @@ class GPT(nn.Module):
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
-         for block in self.transformer.h:
-             x = block(x)
+
+         if self.gradient_checkpointing and self.training:
+             # Use gradient checkpointing for the transformer layers: activations are
+             # recomputed during the backward pass instead of stored, trading compute for memory
+             x = torch.utils.checkpoint.checkpoint_sequential(
+                 self.transformer.h,      # the stack of transformer blocks
+                 len(self.transformer.h), # one checkpoint segment per block
+                 x                        # input activations
+             )
+         else:
+             for block in self.transformer.h:
+                 x = block(x)
+
        x = self.transformer.ln_f(x)

        if targets is not None:
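
Since the new flag lives on GPTConfig, enabling checkpointing is a one-line config change. A minimal, hypothetical sketch (the field values mirror the train.py settings below; checkpointing only applies while the model is in training mode):

    # hypothetical usage of the new flag, not in this commit
    from model import GPT, GPTConfig

    config = GPTConfig(n_layer=24, n_head=16, n_embd=1024, block_size=2048,
                       bias=False, gradient_checkpointing=True)
    model = GPT(config)
    model.train()  # forward() only checkpoints when self.training is True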
train.py CHANGED
@@ -47,13 +47,13 @@ wandb_run_name = 'gpt2' # 'run' + str(time.time())
dataset = 'openwebtext'
gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
- block_size = 1024
- # model
- n_layer = 12
- n_head = 12
- n_embd = 768
- dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
- bias = False # do we use bias inside LayerNorm and Linear layers?
+ block_size = 2048 # increased context length
+ # model (scaled up; roughly 0.35B parameters at these settings)
+ n_layer = 24 # scaled up from 12
+ n_head = 16 # scaled up from 12
+ n_embd = 1024 # scaled up from 768
+ dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
+ bias = False # do we use bias inside LayerNorm and Linear layers?
# adamw optimizer
learning_rate = 6e-4 # max learning rate
max_iters = 600000 # total number of training iterations
@@ -70,8 +70,11 @@ min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
backend = 'nccl' # 'nccl', 'gloo', etc.
# system
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
- dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
+ dtype = 'float16' # use fp16 training with gradient scaling
compile = True # use PyTorch 2.0 to compile the model to be faster
+ # mixed precision and memory optimization
+ use_amp = True # use automatic mixed precision (fp16)
+ gradient_checkpointing = True # trade compute for memory
# -----------------------------------------------------------------------------
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open('configurator.py').read()) # overrides from command line or config file
@@ -111,20 +114,43 @@ device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

- # poor man's data loader
data_dir = os.path.join('data', dataset)
+ # streaming data loader for training; validation below still reads the memmap'd val.bin from data_dir
+ from data.streaming_dataset import StreamingDataset
+
+ dataset_configs = [
+     {'name': 'openwebtext', 'weight': 0.4},
+     {'name': 'the_pile', 'weight': 0.3},
+     {'name': 'red_pajama', 'weight': 0.3}
+ ]
+
+ train_dataset = StreamingDataset(dataset_configs, block_size=block_size, batch_size=batch_size)
+ train_loader = torch.utils.data.DataLoader(
+     train_dataset,
+     batch_size=None, # batch size is handled by the dataset
+     num_workers=4,
+     pin_memory=True
+ )
+ train_iter = iter(train_loader)
+
def get_batch(split):
-     # We recreate np.memmap every batch to avoid a memory leak, as per
-     # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
+     global train_iter # reassigned below when the streaming iterator is exhausted
    if split == 'train':
-         data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
+         try:
+             batch = next(train_iter)
+         except StopIteration:
+             # Reset iterator when exhausted
+             train_iter = iter(train_loader)
+             batch = next(train_iter)
+
+         x = batch[:, :-1] # all but last token
+         y = batch[:, 1:] # all but first token
    else:
+         # For validation, keep the original approach with memmap files
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
-     ix = torch.randint(len(data) - block_size, (batch_size,))
-     x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
-     y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
+         ix = torch.randint(len(data) - block_size, (batch_size,))
+         x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
+         y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
+
    if device_type == 'cuda':
-         # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
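
With dtype = 'float16', the fp16 path relies on loss scaling, as in upstream nanoGPT's training loop. For reference, a simplified sketch of a single fp16 step using the streaming get_batch (illustrative only, not part of the commit; the real loop also handles gradient accumulation, clipping, the LR schedule, and DDP):

    # simplified single-step sketch, not in this commit
    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

    X, Y = get_batch('train')          # streaming batch; shapes (batch_size, block_size - 1)
    with ctx:                          # the autocast context defined earlier in train.py
        logits, loss = model(X, Y)
    scaler.scale(loss).backward()      # scale the loss so fp16 gradients don't underflow
    scaler.step(optimizer)             # unscales gradients, then applies the update
    scaler.update()                    # adjusts the loss scale for the next iteration
    optimizer.zero_grad(set_to_none=True)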