crpatel committed on
Commit ab3efd5 · 1 Parent(s): fb995a8

gradio app

Files changed (6)
  1. app.py +177 -0
  2. config_smollm2_135M.yaml +108 -0
  3. deepseek_v3.py +459 -0
  4. requirements.txt +14 -0
  5. train.py +417 -0
  6. utils.py +182 -0
app.py ADDED
@@ -0,0 +1,177 @@
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer
4
+ import yaml
5
+ from deepseek_v3 import DeepSeekV3Model
6
+ import os
7
+
8
+
9
+ def generate_helper(model, idx, max_new_tokens, context_length, temperature=1.0, top_k=None, eos_token=None, device=None):
10
+
11
+ model = model.to(device)
12
+ idx = idx.to(device)
13
+ model.eval()
14
+ for _ in range(max_new_tokens):
15
+ idx_cond = idx[:, -context_length:]
16
+ with torch.no_grad():
17
+ logits, _ = model(idx_cond) # Unpack both logits and loss (ignore loss)
18
+ logits = logits.view(idx_cond.shape[0], -1, model.config['vocab_size']) # Reshape to [batch, seq, vocab]
19
+
20
+ # Get the logits for the last token only
21
+ logits = logits[:, -1, :] # Shape: [batch_size, vocab_size]
22
+
23
+ if top_k is not None:
24
+ # top k sampling
25
+ top_logits, top_pos = torch.topk(logits, top_k)
26
+ min_logit = top_logits[:, -1].unsqueeze(-1)
27
+ logits = torch.where(logits < min_logit,
28
+ torch.tensor(float('-inf')).to(logits.device),
29
+ logits)
30
+
31
+ # temperature scaling
32
+ if temperature > 0.0:
33
+ logits /= temperature
34
+ probs = torch.softmax(logits, dim=-1)
35
+ idx_next = torch.multinomial(probs, num_samples=1)
36
+ else:
37
+ idx_next = torch.argmax(logits, dim=-1, keepdim=True)
38
+
39
+ if idx_next.item() == eos_token:
40
+ break
41
+
42
+ idx = torch.cat((idx, idx_next), dim=1)
43
+ model.train()
44
+ return idx
45
+
46
+ def get_config(config_path):
47
+ config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader)
48
+ return config
49
+
50
+ def extract_and_save_weights(config_path, checkpoint_path, weights_path, device):
51
+ """Extract model weights from checkpoint and save as a separate .pt file"""
52
+ print(f"Extracting weights from checkpoint: {checkpoint_path}")
53
+ config = get_config(config_path)
54
+ model = DeepSeekV3Model(config['model'])
55
+
56
+ # Load checkpoint
57
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
58
+ state_dict = checkpoint['model_state_dict']
59
+ state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
60
+
61
+ # Save just the model weights
62
+ torch.save(state_dict, weights_path)
63
+ print(f"Model weights saved to: {weights_path}")
64
+ return state_dict
65
+
66
+ def load_weights(config, weights_path, device):
67
+ """Load model from weights file"""
68
+ print(f"Loading model from weights: {weights_path}")
69
+ model = DeepSeekV3Model(config['model'])
70
+ state_dict = torch.load(weights_path, map_location=torch.device(device))
71
+ model.load_state_dict(state_dict)
72
+ return model
73
+
74
+ def get_tokenizer(config):
75
+ tokenizer_path = config['tokenizer']['tokenizer_name_or_path']
76
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
77
+ tokenizer.pad_token = tokenizer.eos_token
78
+ vocab_size = tokenizer.vocab_size
79
+ return tokenizer, vocab_size
80
+
81
+ def generate_text(model, tokenizer, input_text, max_new_tokens, context_length, temperature, top_k, eos_token, device):
82
+ encoded_text = tokenizer.encode(input_text, return_tensors="pt").to(device)
83
+ generated_text = generate_helper(model,
84
+ idx=encoded_text,
85
+ max_new_tokens=max_new_tokens,
86
+ context_length=context_length,
87
+ temperature=temperature,
88
+ top_k=top_k,
89
+ eos_token=eos_token,
90
+ device=device)
91
+ return tokenizer.decode(generated_text.squeeze(0))
92
+
93
+
94
+
95
+ # Initialize model and tokenizer
96
+ def initialize_model():
97
+ config_path = "config_smollm2_135M.yaml"
98
+ # Use HF Hub or another external storage instead of local path
99
+ model_id = "crpatel/DeepSeek-V3-SmolLm2" # Replace with your actual model ID
100
+ weights_path = "model.pt"
101
+ device = "cuda" if torch.cuda.is_available() else "cpu"
102
+
103
+ # Load configuration
104
+ config = get_config(config_path)
105
+
106
+ # Check if weights exist locally, otherwise download from HF Hub
107
+ if not os.path.exists(weights_path):
108
+ try:
109
+ from huggingface_hub import hf_hub_download
110
+ print(f"Downloading model weights from Hugging Face Hub: {model_id}")
111
+ weights_path = hf_hub_download(
112
+ repo_id=model_id,
113
+ filename="model.pt"
114
+ )
115
+ except Exception as e:
116
+ print(f"Error downloading weights: {e}")
117
+ print("Falling back to local checkpoint extraction if available")
118
+ checkpoint_path = "checkpoints/model_100000_step_avg_loss_4.61663.pth"
119
+ if os.path.exists(checkpoint_path):
120
+ extract_and_save_weights(config_path, checkpoint_path, weights_path, device)
121
+ else:
122
+ raise FileNotFoundError("Neither weights file nor checkpoint found. Please upload model to HF Hub first.")
123
+
124
+ # Load model from weights
125
+ model = load_weights(config, weights_path, device)
126
+ model.to(device)
127
+ model.eval()
128
+
129
+ # Load tokenizer
130
+ tokenizer, vocab_size = get_tokenizer(config)
131
+
132
+ return model, tokenizer, device
133
+
134
+ def generate_response(prompt, max_new_tokens):
135
+ generated_text = generate_text(
136
+ model=model,
137
+ tokenizer=tokenizer,
138
+ input_text=prompt,
139
+ max_new_tokens=max_new_tokens,
140
+ context_length=256,
141
+ temperature=0.9,
142
+ top_k=2,
143
+ eos_token=tokenizer.eos_token_id,
144
+ device=device
145
+ )
146
+ return generated_text
147
+
148
+ # Initialize global variables
149
+ model, tokenizer, device = initialize_model()
150
+
151
+ # Create Gradio interface
152
+ iface = gr.Interface(
153
+ fn=generate_response,
154
+ inputs=[
155
+ gr.Textbox(
156
+ lines=3,
157
+ placeholder="Enter your prompt here...",
158
+ label="Input Prompt"
159
+ ),
160
+ gr.Slider(
161
+ minimum=50,
162
+ maximum=256,
163
+ value=100,
164
+ step=10,
165
+ label="Max New Tokens"
166
+ )
167
+ ],
168
+ outputs=gr.Textbox(
169
+ lines=5,
170
+ label="Generated Text"
171
+ ),
172
+ title="DeepSeek-V3 Text Generator",
173
+ description="Enter a prompt and adjust the maximum number of tokens to generate text with DeepSeek-V3 SmolLM2 model."
174
+ )
175
+
176
+ if __name__ == "__main__":
177
+ iface.launch()
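
A minimal sketch (not part of the commit) of driving the same helpers without the Gradio UI, assuming config_smollm2_135M.yaml and a local model.pt are available; the argument values mirror generate_response above:

config = get_config("config_smollm2_135M.yaml")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_weights(config, "model.pt", device)
model.to(device).eval()
tokenizer, _ = get_tokenizer(config)
print(generate_text(model, tokenizer, "Once upon a time", max_new_tokens=100,
                    context_length=256, temperature=0.9, top_k=2,
                    eos_token=tokenizer.eos_token_id, device=device))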
config_smollm2_135M.yaml ADDED
@@ -0,0 +1,108 @@
1
+ checkpoints:
2
+ checkpoint_interval: 2000
3
+ checkpoints_path: checkpoints
4
+ checkpoints_path_is_shared_file_system: false
5
+ resume_checkpoint_path: null
6
+ save_final_state: false
7
+ save_initial_state: false
8
+ data_stages:
9
+ - data:
10
+ dataset:
11
+ dataset_folder:
12
+ - datasets/smollm2-corpus
13
+ dataset_weights:
14
+ - 1.0
15
+ num_loading_workers: 0
16
+ seed: 8
17
+ name: stable phase
18
+ start_training_step: 1
19
+ general:
20
+ benchmark_csv_path: null
21
+ consumed_train_samples: null
22
+ ignore_sanity_checks: true
23
+ project: smollm2
24
+ run: smollm2-135M
25
+ seed: 8
26
+ step: null
27
+ logging:
28
+ iteration_step_info_interval: 1
29
+ log_level: info
30
+ log_level_replica: info
31
+ model:
32
+ ddp_bucket_cap_mb: 25
33
+ dtype: bfloat16
34
+ init_method:
35
+ std: 0.041666666666666664
36
+ make_vocab_size_divisible_by: 1
37
+ model_config:
38
+ bos_token_id: 0
39
+ eos_token_id: 0
40
+ hidden_act: silu
41
+ hidden_size: 576
42
+ initializer_range: 0.041666666666666664
43
+ intermediate_size: 1536
44
+ is_llama_config: true
45
+ max_position_embeddings: 2048
46
+ num_attention_heads: 9
47
+ num_hidden_layers: 30
48
+ num_key_value_heads: 3
49
+ pad_token_id: null
50
+ pretraining_tp: 1
51
+ rms_norm_eps: 1.0e-05
52
+ rope_interleaved: false
53
+ rope_scaling: null
54
+ rope_theta: 10000.0
55
+ tie_word_embeddings: true
56
+ use_cache: true
57
+ vocab_size: 49152
58
+ s3_bucket: deepseek-v3-train-mar-2025
59
+ s3_checkpoint_folder: checkpoints
60
+ s3_log_folder: logs
61
+ s3_log_file_name: training.log
62
+ # deepseek
63
+ compression_ratio: 4
64
+ num_experts: 4
65
+ num_shared_experts: 1
66
+ top_k: 2
67
+ optimizer:
68
+ accumulate_grad_in_fp32: true
69
+ clip_grad: 1.0
70
+ learning_rate_scheduler:
71
+ learning_rate: 0.003
72
+ lr_decay_starting_step: 1600000
73
+ lr_decay_steps: 400000
74
+ lr_decay_style: linear
75
+ lr_warmup_steps: 2000
76
+ lr_warmup_style: linear
77
+ min_decay_lr: 0
78
+ optimizer_factory:
79
+ adam_beta1: 0.9
80
+ adam_beta2: 0.95
81
+ adam_eps: 1.0e-08
82
+ name: adamW
83
+ torch_adam_is_fused: true
84
+ weight_decay: 0.01
85
+ zero_stage: 0
86
+ parallelism:
87
+ dp: 64
88
+ expert_parallel_size: 1
89
+ pp: 1
90
+ pp_engine: 1f1b
91
+ recompute_layer: false
92
+ tp: 1
93
+ tp_linear_async_communication: true
94
+ tp_mode: REDUCE_SCATTER
95
+ tp_recompute_allgather: true
96
+ profiler: null
97
+ tokenizer:
98
+ tokenizer_max_length: null
99
+ tokenizer_name_or_path: HuggingFaceTB/cosmo2-tokenizer
100
+ tokenizer_revision: null
101
+ tokens:
102
+ batch_accumulation_per_replica: 1
103
+ limit_test_batches: 0
104
+ limit_val_batches: 0
105
+ micro_batch_size: 8
106
+ sequence_length: 512
107
+ train_steps: 2000000
108
+ val_check_interval: 1000
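
For reference, a small sketch (not part of the committed files) of the dimensions this model_config implies for MultiHeadLatentAttention in deepseek_v3.py:

hidden_size = 576
num_attention_heads = 9
compression_ratio = 4
head_dim = hidden_size // num_attention_heads   # 576 // 9 = 64
latent_dim = hidden_size // compression_ratio   # 576 // 4 = 144 (width of the compressed KV/Q projections)
rope_width = hidden_size // 2                   # 288 in total, i.e. head_dim // 2 = 32 per head for the RoPE half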
deepseek_v3.py ADDED
@@ -0,0 +1,459 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import SiLU
5
+ import yaml
6
+
7
+
8
+ def _init_weights(module, std=0.041666666666666664):
9
+ if isinstance(module, nn.Linear):
10
+ module.weight.data.normal_(mean=0.0, std=std)
11
+ elif isinstance(module, nn.Embedding):
12
+ module.weight.data.normal_(mean=0.0, std=std)
13
+
14
+ class RotaryPositionalEmbedding(nn.Module):
15
+ """
16
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L240
17
+ Rotary Positional Embedding (RoPE) for transformers. Implementation derived from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
18
+ """
19
+ def __init__(self, dim: int, theta: float = 10000.0):
20
+ super().__init__()
21
+ self.dim = dim
22
+ self.theta = theta
23
+
24
+ def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
25
+ """
26
+ Apply rotary positional embedding to the input tensor.
27
+
28
+ Args:
29
+ x (torch.Tensor): Input tensor of shape [B, T, H, D] or [B, T, D]
30
+ seq_len (int): Sequence length.
31
+
32
+ Returns:
33
+ torch.Tensor: Output tensor with rotary positional embeddings applied.
34
+ """
35
+ # Handle different input shapes
36
+ if len(x.shape) == 3:
37
+ B, T, D = x.shape
38
+ is_4d = False
39
+ else:
40
+ B, T, H, D = x.shape
41
+ is_4d = True
42
+
43
+ # For 3D tensors, we need to ensure D is even
44
+ if not is_4d and D % 2 != 0:
45
+ raise ValueError(f"Feature dimension {D} must be divisible by 2 for RoPE")
46
+
47
+ # Generate position indices
48
+ position = torch.arange(T, dtype=torch.float32, device=x.device).unsqueeze(-1)
49
+
50
+ # Generate frequencies
51
+ if is_4d:
52
+ # For 4D tensors, use the head dimension
53
+ freqs = torch.exp(
54
+ torch.arange(0, D, 2, dtype=torch.float32, device=x.device) *
55
+ -(torch.log(torch.tensor(self.theta)) / D)
56
+ )
57
+ else:
58
+ # For 3D tensors, use the full dimension
59
+ freqs = torch.exp(
60
+ torch.arange(0, D, 2, dtype=torch.float32, device=x.device) *
61
+ -(torch.log(torch.tensor(self.theta)) / D)
62
+ )
63
+
64
+ # Compute sinusoids
65
+ sinusoid = position * freqs
66
+ sin = torch.sin(sinusoid)
67
+ cos = torch.cos(sinusoid)
68
+
69
+ # Reshape sin and cos to match the input tensor's shape
70
+ if is_4d:
71
+ sin = sin.unsqueeze(0).unsqueeze(2) # Shape: (1, T, 1, D // 2)
72
+ cos = cos.unsqueeze(0).unsqueeze(2) # Shape: (1, T, 1, D // 2)
73
+ else:
74
+ sin = sin.unsqueeze(0) # Shape: (1, T, D // 2)
75
+ cos = cos.unsqueeze(0) # Shape: (1, T, D // 2)
76
+
77
+ # Apply rotary embeddings
78
+ x_rotated = x.clone()
79
+
80
+ if is_4d:
81
+ x_rotated[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin
82
+ x_rotated[..., 1::2] = x[..., 1::2] * cos + x[..., 0::2] * sin
83
+ else:
84
+ x_rotated[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin
85
+ x_rotated[..., 1::2] = x[..., 1::2] * cos + x[..., 0::2] * sin
86
+
87
+ return x_rotated
88
+
89
+ class MultiHeadLatentAttention(nn.Module):
90
+ def __init__(self, config):
91
+ super().__init__()
92
+ self.config = config
93
+ self.num_attention_heads = self.config['num_attention_heads']
94
+ self.hidden_size = self.config['hidden_size']
95
+ # Ensure the hidden size is divisible by the number of attention heads
96
+ if self.hidden_size % self.num_attention_heads != 0:
97
+ raise ValueError(
98
+ f"hidden_size ({self.hidden_size}) must be divisible by num_attention_heads ({self.num_attention_heads})"
99
+ )
100
+ self.head_dim = self.hidden_size // self.num_attention_heads
101
+ self.latent_dim = self.hidden_size // self.config['compression_ratio']
102
+
103
+ # Matrix is decomposed into D and U matrix
104
+ # Compression KV Projection Matrix
105
+ self.kv_proj_D = nn.Linear(self.hidden_size, self.latent_dim, bias=False)
106
+ # Compression Q Projection Matrix
107
+ self.q_proj_D = nn.Linear(self.hidden_size, self.latent_dim, bias=False)
108
+
109
+ # Decompression (up-projection) matrix for K
110
+ self.k_proj_U = nn.Linear(self.latent_dim, self.hidden_size//2, bias=False)
111
+ # Decompression (up-projection) matrix for V
112
+ self.v_proj_U = nn.Linear(self.latent_dim, self.hidden_size, bias=False)
113
+ # Decompression (up-projection) matrix for Q
114
+ self.q_proj_U = nn.Linear(self.latent_dim, self.hidden_size//2, bias=False)
115
+
116
+ # RoPE components: K is built from x and Q is built from q_proj_D
117
+ self.rope_k = nn.Linear(self.hidden_size, self.hidden_size//2, bias=False)
118
+ self.rope_q = nn.Linear(self.latent_dim, self.hidden_size//2, bias=False)
119
+ # output projection matrix
120
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
121
+
122
+ self.rotary_emb = RotaryPositionalEmbedding(self.hidden_size//2, self.config['rope_theta'])
123
+
124
+ def forward(self, x, attn_mask=None):
125
+ B, T, C = x.size() # Batch Size, Sequence Length, Hidden Size
126
+ # Compression KV Projection Matrix
127
+ kv_d = self.kv_proj_D(x) # [B, T, Latent Dim]
128
+ # Compression Q Projection Matrix
129
+ q_d = self.q_proj_D(x) # [B, T, Latent Dim]
130
+ # Uncompress KV & Q Projection Matrix
131
+ k_proj_2 = self.k_proj_U(kv_d) # [B, T, Hidden Size//2]
132
+ q_proj_2 = self.q_proj_U(q_d) # [B, T, Hidden Size//2]
133
+ v = self.v_proj_U(kv_d) # [B, T, Hidden Size]
134
+
135
+ # Rope components
136
+ k_rope_2 = self.rope_k(x) # [B, T, Hidden Size//2]
137
+ q_rope_2 = self.rope_q(q_d) # [B, T, Hidden Size//2]
138
+
139
+ # Apply ROPE to the rope components
140
+ k_rope_2 = self.rotary_emb(k_rope_2, T) # [B, T, Hidden Size//2]
141
+ q_rope_2 = self.rotary_emb(q_rope_2, T) # [B, T, Hidden Size//2]
142
+
143
+ # Reshape Components for Multi-Head Attention
144
+ k_proj_2 = k_proj_2.view(B, T, self.num_attention_heads, self.head_dim//2)
145
+ k_rope_2 = k_rope_2.view(B, T, self.num_attention_heads, self.head_dim//2)
146
+ q_proj_2 = q_proj_2.view(B, T, self.num_attention_heads, self.head_dim//2)
147
+ q_rope_2 = q_rope_2.view(B, T, self.num_attention_heads, self.head_dim//2)
148
+
149
+ # Concatenate Components
150
+ k = torch.cat((k_proj_2, k_rope_2), dim=-1) # [B, T, H, D]
151
+ q = torch.cat((q_proj_2, q_rope_2), dim=-1) # [B, T, H, D]
152
+ v = v.view(B, T, self.num_attention_heads, self.head_dim)
153
+
154
+ # Reshape Components for Multi-Head Attention
155
+ k = k.transpose(1, 2) # [B, H, T, D]
156
+ q = q.transpose(1, 2) # [B, H, T, D]
157
+ v = v.transpose(1, 2) # [B, H, T, D]
158
+
159
+ # Apply Scaled Dot-Product Attention
160
+ attn_out = F.scaled_dot_product_attention(q, k, v,
161
+ dropout_p=0.0,
162
+ is_causal=True,
163
+ attn_mask=attn_mask)
164
+ attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, C) # [B, T, C]
165
+ return self.o_proj(attn_out) # [B, T, C]
166
+
167
+ class DeepSeekExpertLayer(nn.Module):
168
+ def __init__(self, hidden_size, intermediate_size):
169
+ super().__init__()
170
+ self.hidden_size = hidden_size
171
+ self.intermediate_size = intermediate_size
172
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
173
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
174
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
175
+ self.act_fn = SiLU()
176
+
177
+ def forward(self, x):
178
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
179
+
180
+ class DeepSeekMOE(nn.Module):
181
+ """
182
+ A Mixture of Experts (MoE) layer that routes input through a set of expert layers.
183
+
184
+ This class implements a mixture of experts mechanism where a subset of experts is selected
185
+ for each input token based on learned routing logits. The output is a combination of the
186
+ shared experts and the routed experts, allowing for efficient computation and increased
187
+ model capacity.
188
+
189
+ Attributes:
190
+ hidden_size (int): The size of the hidden layer.
191
+ intermediate_size (int): The size of the intermediate layer.
192
+ num_experts (int): Total number of experts available.
193
+ num_shared_experts (int): Number of shared experts that are used for all inputs.
194
+ top_k (int): The number of top experts to route each input to.
195
+ shared_experts (nn.ModuleList): List of shared expert layers.
196
+ routed_experts (nn.ModuleList): List of routed expert layers.
197
+ routing_fn (nn.Linear): Linear layer for computing routing logits.
198
+ routing_bias (nn.Parameter): Bias for the routing logits.
199
+
200
+ Methods:
201
+ forward(x): Forward pass through the MoE layer, routing input through selected experts.
202
+ """
203
+ def __init__(self, hidden_size, intermediate_size, num_experts, num_shared_experts, top_k):
204
+ super().__init__()
205
+ self.hidden_size = hidden_size
206
+ self.intermediate_size = intermediate_size
207
+ self.num_experts = num_experts
208
+ self.num_shared_experts = num_shared_experts
209
+ self.top_k = top_k
210
+ self.num_routed_experts = num_experts - num_shared_experts
211
+ self.shared_experts = nn.ModuleList(
212
+ [DeepSeekExpertLayer(self.hidden_size, self.intermediate_size) for _ in range(self.num_shared_experts)]
213
+ )
214
+ self.routed_experts = nn.ModuleList(
215
+ [DeepSeekExpertLayer(self.hidden_size, self.intermediate_size) for _ in range(self.num_routed_experts)]
216
+ )
217
+
218
+ # Routing Function
219
+ self.routing_fn = nn.Linear(self.hidden_size, self.num_routed_experts, bias=False)
220
+ self.routing_bias = nn.Parameter(torch.zeros(self.num_routed_experts))
221
+ def forward(self, x):
222
+ B, T, C = x.size()
223
+ shared_out = sum(expert(x) for expert in self.shared_experts)
224
+ if self.num_shared_experts>1:
225
+ shared_out = shared_out/self.num_shared_experts # normalize the shared experts
226
+ # calculate the routing function
227
+ routing_logits = self.routing_fn(x) + self.routing_bias # [B, T, num_routed_experts]
228
+ # Get top-k experts per token
229
+ routing_probs = torch.sigmoid(routing_logits) # [B, T, num_routed_experts]
230
+ scores, indices = torch.topk(routing_probs, self.top_k, dim=-1) # [B, T, top_k]
231
+ # normalize the top k scores
232
+ scores = scores/torch.sum(scores, dim=-1, keepdim=True)
233
+ # process the routed experts
234
+ #combined_output = torch.zeros(B, T, C, device=x.device)
235
+ combined_output = torch.zeros_like(x)
236
+
237
+ # Calculate expert load for all experts
238
+ expert_load = torch.zeros(self.num_routed_experts, device=x.device)
239
+ for i in range(self.top_k):
240
+ expert_idx = indices[:, :, i] # [B, T]
241
+
242
+ expert_scores = scores[...,i:i+1]
243
+ # process the routed experts
244
+ for j in range(self.num_routed_experts):
245
+ mask = (expert_idx == j) # [B, T] boolean mask
246
+ if mask.any():
247
+ # Track expert usage (load)
248
+ expert_load[j] += mask.sum().float() / (B * T * self.top_k)
249
+ # Process tokens through this expert
250
+ expert_input = x[mask] # [num_selected_tokens, C]
251
+ expert_output = self.routed_experts[j](expert_input)
252
+ combined_output[mask] += expert_scores[mask] * expert_output
253
+ final_output = shared_out + combined_output
254
+ router_z_loss = self.update_bias_terms(expert_load)
255
+ return final_output, router_z_loss
256
+
257
+ def update_bias_terms(self, expert_load, router_z_loss_coef=0.001):
258
+ # Balance expert routing by adjusting the bias terms
259
+ # Target load is uniform distribution across experts
260
+ target_load = 1.0 / self.num_routed_experts
261
+
262
+ # Calculate load imbalance for each expert
263
+ load_diff = expert_load - target_load
264
+
265
+ # Dynamic update rate based on the magnitude of imbalance
266
+ # Larger imbalances get larger corrections
267
+ update_rate = 0.1 * torch.abs(load_diff)
268
+
269
+ # Update the routing bias to counteract imbalance
270
+ # Decrease bias for overutilized experts, increase for underutilized
271
+ self.routing_bias.data -= update_rate * load_diff
272
+
273
+ # Calculate the router z-loss to discourage extreme routing probabilities
274
+ # This helps stabilize training without auxiliary losses
275
+ # Z-loss encourages routing probabilities to stay away from 0 and 1
276
+ router_z_loss = router_z_loss_coef * torch.mean(torch.log(torch.sum(
277
+ torch.exp(self.routing_fn.weight), dim=-1)))
278
+
279
+ return router_z_loss
280
+
281
+ def update_bias_terms_old(self, expert_load):
282
+ # adjust the bias terms based on the expert load
283
+ target_load = 1/self.num_experts
284
+ load_diff = expert_load - target_load
285
+ # dynamically update the bias based on the load imbalance
286
+ update_rate = 0.1 * torch.abs(load_diff)
287
+ # dynamically update the bias terms using the update rate
288
+ self.routing_bias = self.routing_bias - update_rate * load_diff
289
+
290
+ # for i in range(self.num_routed_experts):
291
+ # if expert_load[i] < target_load:
292
+ # self.routing_bias[i] -= 1
293
+ # else:
294
+ # self.routing_bias[i] += 1
295
+ class LlamaMLP(nn.Module):
296
+ """
297
+ (mlp): LlamaMLP(
298
+ (moe): DeepSeekMOE(
299
+ (shared_experts): ModuleList(
300
+ (0): DeepSeekExpertLayer(
301
+ (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
302
+ (up_proj): Linear(in_features=576, out_features=1536, bias=False)
303
+ (down_proj): Linear(in_features=1536, out_features=576, bias=False)
304
+ (act_fn): SiLU()
305
+ )
306
+ )
307
+ (routed_experts): ModuleList(
308
+ (0-2): 3 x DeepSeekExpertLayer(
309
+ (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
310
+ (up_proj): Linear(in_features=576, out_features=1536, bias=False)
311
+ (down_proj): Linear(in_features=1536, out_features=576, bias=False)
312
+ (act_fn): SiLU()
313
+ )
314
+ )
315
+ (routing_fn): Linear(in_features=576, out_features=3, bias=False)
316
+ )
317
+ )
318
+ """
319
+ def __init__(self, config):
320
+ super().__init__()
321
+ self.config = config
322
+ self.moe = DeepSeekMOE(hidden_size=config['hidden_size'],
323
+ intermediate_size=config['intermediate_size'],
324
+ num_experts=config['num_experts'],
325
+ num_shared_experts= config['num_shared_experts'],
326
+ top_k=config['top_k'])
327
+ # self.gate_proj = nn.Linear(self.config['hidden_size'], self.config['intermediate_size'], bias=False)
328
+ # self.up_proj = nn.Linear(self.config['hidden_size'], self.config['intermediate_size'], bias=False)
329
+ # self.down_proj = nn.Linear(self.config['intermediate_size'], self.config['hidden_size'], bias=False)
330
+ # self.act_fn = SiLU()
331
+ def forward(self, x):
332
+ output, router_z_loss = self.moe(x)
333
+ return output, router_z_loss
334
+ # gate = self.gate_proj(x)
335
+ # up = self.up_proj(x)
336
+ # down = self.down_proj(self.act_fn(gate)*up)
337
+ # return down
338
+
339
+ class LlamaRMSNorm(nn.Module):
340
+ """
341
+ (norm): LlamaRMSNorm((576,), eps=1e-05)
342
+ # RMSNorm Formula:
343
+ # RMS(x) = sqrt((1 / d) * sum(x_i^2 for i in range(d)))
344
+ # x_normalized = x / RMS(x)
345
+ # output = gamma * x_normalized
346
+
347
+ """
348
+ def __init__(self, config):
349
+ super().__init__()
350
+ self.config = config
351
+ self.eps = self.config['rms_norm_eps']
352
+ self.weight = nn.Parameter(torch.ones(self.config['hidden_size']))
353
+ def forward(self, x):
354
+ rms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
355
+ return self.weight * rms * x
356
+
357
+ class LlamaDecoderLayer(nn.Module):
358
+ def __init__(self, config):
359
+ super().__init__()
360
+ self.config = config
361
+ self.self_attn = MultiHeadLatentAttention(self.config)
362
+ self.input_layernorm = LlamaRMSNorm(self.config)
363
+ self.mlp = LlamaMLP(self.config)
364
+ self.post_attention_layernorm = LlamaRMSNorm(self.config)
365
+
366
+ def forward(self, x):
367
+ residual = x
368
+ x = self.input_layernorm(x)
369
+ x = self.self_attn(x)
370
+ x = x + residual
371
+ residual = x
372
+ x = self.post_attention_layernorm(x)
373
+ x, router_z_loss = self.mlp(x)
374
+ x = x + residual
375
+ return x, router_z_loss
376
+
377
+ class DeepSeekV3Model(nn.Module):
378
+ def __init__(self, config):
379
+ super().__init__()
380
+ self.init_method = config['init_method']
381
+ self.config = config['model_config']
382
+ self.embed_tokens = nn.Embedding(self.config['vocab_size'], self.config['hidden_size'])
383
+ self.rotary_emb = RotaryPositionalEmbedding(self.config['hidden_size'], self.config['rope_theta'])
384
+ self.layers = nn.ModuleList([LlamaDecoderLayer(self.config) for _ in range(self.config['num_hidden_layers'])])
385
+ self.norm = LlamaRMSNorm(self.config)
386
+ self.lm_head = nn.Linear(self.config['hidden_size'], self.config['vocab_size'], bias=False)
387
+
388
+ if self.config['tie_word_embeddings']:
389
+ self.lm_head.weight = self.embed_tokens.weight
390
+
391
+ self.apply(lambda m: _init_weights(m, self.init_method['std']))
392
+
393
+ def forward(self, x, y=None):
394
+ x = self.embed_tokens(x)
395
+ total_router_z_loss = 0.0
396
+ for layer in self.layers:
397
+ x, router_z_loss = layer(x)
398
+ total_router_z_loss += router_z_loss
399
+ x = self.norm(x)
400
+ logits = self.lm_head(x) # B,T,V
401
+ logits = logits.view(-1, logits.size(-1)) # Shape: [B*T, V] # 20, 49152
402
+ if y is not None:
403
+ y = y.view(-1) # Shape: [B*T] # 20
404
+ ce_loss = torch.nn.functional.cross_entropy(logits, y)
405
+ # Combine CE loss with router z-loss
406
+ loss = ce_loss + total_router_z_loss
407
+ return logits, loss
408
+ else:
409
+ return logits, None
410
+
411
+
412
+ def generate(self, idx, max_new_tokens, context_length, temperature=1.0, top_k=None, eos_token=None, device=None):
413
+ model = self.to(device)
414
+ idx = idx.to(device)
415
+ model.eval()
416
+ for _ in range(max_new_tokens):
417
+ idx_cond = idx[:, -context_length:]
418
+ with torch.no_grad():
419
+ logits, _ = model(idx_cond) # Unpack both logits and loss (ignore loss)
420
+ logits = logits.view(idx_cond.shape[0], -1, model.config['vocab_size']) # Reshape to [batch, seq, vocab]
421
+
422
+ # Get the logits for the last token only
423
+ logits = logits[:, -1, :] # Shape: [batch_size, vocab_size]
424
+
425
+ if top_k is not None:
426
+ # top k sampling
427
+ top_logits, top_pos = torch.topk(logits, top_k)
428
+ min_logit = top_logits[:, -1].unsqueeze(-1)
429
+ logits = torch.where(logits < min_logit,
430
+ torch.tensor(float('-inf')).to(logits.device),
431
+ logits)
432
+
433
+ # temperature scaling
434
+ if temperature > 0.0:
435
+ logits /= temperature
436
+ probs = torch.softmax(logits, dim=-1)
437
+ idx_next = torch.multinomial(probs, num_samples=1)
438
+ else:
439
+ idx_next = torch.argmax(logits, dim=-1, keepdim=True)
440
+
441
+ if idx_next.item() == eos_token:
442
+ break
443
+
444
+ idx = torch.cat((idx, idx_next), dim=1)
445
+ model.train()
446
+ return idx
447
+
448
+ # if __name__ == "__main__":
449
+ # torch.manual_seed(0)
450
+ # config = yaml.load(open("config_smollm2_135M.yaml", "r"), Loader=yaml.FullLoader)
451
+ # print(config.keys())
452
+ # model_config = config['model']['model_config']
453
+ # print(model_config)
454
+ # model = DeepSeekV3Model(config['model'])
455
+ # x_tokens = torch.randint(0, model_config['vocab_size'], (1, 10)) # Generate random token indices
456
+ # print(model(x_tokens).shape)
457
+ # total_params = sum(p.numel() for p in model.parameters())
458
+ # print(f"Total parameters: {total_params}") #134515008
459
+ # print(model)
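
A quick smoke-test sketch, adapted from the commented-out block above; note the model returns a (logits, loss) tuple and the logits come back flattened to [B*T, vocab_size], so they must be unpacked:

config = yaml.load(open("config_smollm2_135M.yaml", "r"), Loader=yaml.FullLoader)
model = DeepSeekV3Model(config['model'])
x_tokens = torch.randint(0, config['model']['model_config']['vocab_size'], (1, 10))
logits, loss = model(x_tokens)   # loss is None when no targets are passed
print(logits.shape)              # torch.Size([10, 49152])
print(sum(p.numel() for p in model.parameters()))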
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ torch
2
+ torchtext
3
+ pandas
4
+ numpy==1.26.1
5
+ matplotlib
6
+ tiktoken
7
+ tensorflow>=2.15.0
8
+ tqdm
9
+ # urllib
10
+ requests
11
+ boto3
12
+ datasets
13
+ transformers
14
+ gradio
train.py ADDED
@@ -0,0 +1,417 @@
1
+ from deepseek_v3 import DeepSeekV3Model
2
+ import torch
3
+ import yaml
4
+ from transformers import AutoTokenizer
5
+ # from gptdataloader import GPTDataLoader
6
+ from torch.utils.data import DataLoader
7
+ import numpy as np
8
+ from datasets import load_dataset
9
+ import logging
10
+ import math
11
+
12
+ from utils import upload_file_to_s3
13
+ # At the start of training loop
14
+ # print(f"GPU Memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
15
+ # print(f"GPU Memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
16
+
17
+
18
+ logger = logging.getLogger(__name__)
19
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
20
+ file_handler = logging.FileHandler('training.log')
21
+ file_handler.setFormatter(formatter) # Set formatter on the handler, not the logger
22
+ logger.addHandler(file_handler)
23
+ logger.setLevel(logging.INFO)
24
+
25
+ def encode_text(examples, tokenizer, seq_length):
26
+ """Tokenize and prepare text examples for training."""
27
+ tokens = tokenizer(
28
+ examples["text"],
29
+ truncation=True,
30
+ padding="max_length",
31
+ max_length=seq_length + 1,
32
+ return_tensors="pt",
33
+ )
34
+ # Use clone().detach() as recommended
35
+ input_ids = tokens["input_ids"].squeeze(0).clone().detach()
36
+ input_ids = torch.clamp(input_ids, min=0, max=tokenizer.vocab_size - 1)
37
+ labels = input_ids.clone().detach()
38
+ labels = labels[1:].to(torch.int64)
39
+ input_ids = input_ids[:-1].to(torch.int64)
40
+
41
+ return {"input_ids": input_ids, "labels": labels}
42
+
43
+ def load_cosmopedia_dataset(batch_size=8, seq_length=1024, tokenizer=None):
44
+ """
45
+ Returns a torch dataloader for the cosmopedia dataset
46
+ """
47
+ # Set tokenizer parallelism explicitly
48
+ import os
49
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
50
+ logger.info("tokenizer parallelism set to false")
51
+ try:
52
+ # Increase timeout and retries for dataset loading
53
+ from datasets import config
54
+ config.HF_DATASETS_TIMEOUT = 300 # 5 minutes timeout
55
+ config.MAX_RETRIES = 10 # Increase retry attempts
56
+ logger.info("dataset loading config set")
57
+ train_dataset = load_dataset(
58
+ "HuggingFaceTB/smollm-corpus",
59
+ name="cosmopedia-v2",
60
+ split="train",
61
+ streaming=True,
62
+ )
63
+ logger.info("dataset loaded")
64
+
65
+ # Use partial to bind tokenizer and seq_length to the encode function
66
+ from functools import partial
67
+ encode_fn = partial(encode_text, tokenizer=tokenizer, seq_length=seq_length)
68
+
69
+ train_dataset = train_dataset.map(
70
+ encode_fn,
71
+ remove_columns=["text"],
72
+ batched=False
73
+ )
74
+ train_dataset = train_dataset.with_format("torch")
75
+
76
+ train_dataloader = DataLoader(
77
+ train_dataset,
78
+ batch_size=batch_size,
79
+ num_workers=2,
80
+ pin_memory=True,
81
+ prefetch_factor=4,
82
+ persistent_workers=True
83
+ )
84
+ return train_dataloader
85
+ except Exception as e:
86
+ logger.error(f"Error loading dataset: {str(e)}")
87
+ return None
88
+
89
+ # def create_dataloader(file_path, tokenizer, context_size, stride):
90
+ # with open(file_path, "r") as file:
91
+ # text_data = file.read()
92
+ # total_characters = len(text_data)
93
+ # total_tokens = len(tokenizer.encode(text_data))
94
+ # logger.info(f"Characters: {total_characters}")
95
+ # logger.info(f"Tokens: {total_tokens}")
96
+
97
+ # # create dataloader
98
+ # train_ratio = 0.9
99
+ # val_ratio = 0.1
100
+ # split_idx = int(train_ratio * total_characters)
101
+
102
+ # train_data = text_data[:split_idx]
103
+
104
+ # valid_data = text_data[split_idx:]
105
+
106
+ # train_dataset = GPTDataLoader(train_data, tokenizer, context_size, stride)
107
+ # valid_dataset = GPTDataLoader(valid_data, tokenizer, context_size, stride)
108
+ # return DataLoader(train_dataset, batch_size=10, shuffle=True, drop_last=True), DataLoader(valid_dataset, batch_size=10, shuffle=False, drop_last=True)
109
+
110
+
111
+
112
+
113
+ # def calculate_loss_batch(input_batch, target_batch, model, device):
114
+ # input_batch = input_batch.to(device)
115
+ # target_batch = target_batch.to(device)
116
+
117
+ # logits, loss = model(input_batch, target_batch) # e.g. 10, 32, 49152
118
+ # logits = logits.view(-1, logits.size(-1)) # Shape: [320, 49152]
119
+ # target_batch = target_batch.view(-1) # Shape: [320]
120
+ # loss = torch.nn.functional.cross_entropy(logits, target_batch)
121
+ # return loss
122
+
123
+ # def calc_loss_loader(data_loader, model, device, num_batches=None):
124
+ # total_loss = 0.0
125
+ # if len(data_loader) == 0:
126
+ # return float("nan")
127
+ # elif num_batches is None:
128
+ # num_batches = len(data_loader)
129
+ # else:
130
+ # num_batches = min(num_batches, len(data_loader))
131
+ # for i, (input_batch, target_batch) in enumerate(data_loader):
132
+ # if i < num_batches:
133
+ # loss = calculate_loss_batch(input_batch, target_batch, model, device)
134
+ # total_loss += loss.item()
135
+ # else:
136
+ # break
137
+ # return total_loss / num_batches
138
+
139
+ # def evaluate_model(model, train_dataloader, valid_dataloader, device, eval_iter=100):
140
+ # model.eval()
141
+ # with torch.no_grad():
142
+ # train_loss = calc_loss_loader(train_dataloader, model, device, num_batches=eval_iter)
143
+ # valid_loss = calc_loss_loader(valid_dataloader, model, device, num_batches=eval_iter)
144
+ # model.train()
145
+ # return train_loss, valid_loss
146
+
147
+ def generate(model, idx, max_new_tokens, context_length, temperature=1.0, top_k=None, eos_token=None, device=None):
148
+ logger.info(f"Generating on device {device}")
149
+ model = model.to(device)
150
+ idx = idx.to(device)
151
+ model.eval()
152
+ for _ in range(max_new_tokens):
153
+ idx_cond = idx[:, -context_length:]
154
+ with torch.no_grad():
155
+ logits, _ = model(idx_cond) # Unpack both logits and loss (ignore loss)
156
+ logits = logits.view(idx_cond.shape[0], -1, model.config['vocab_size']) # Reshape to [batch, seq, vocab]
157
+
158
+ # Get the logits for the last token only
159
+ logits = logits[:, -1, :] # Shape: [batch_size, vocab_size]
160
+
161
+ if top_k is not None:
162
+ # top k sampling
163
+ top_logits, top_pos = torch.topk(logits, top_k)
164
+ min_logit = top_logits[:, -1].unsqueeze(-1)
165
+ logits = torch.where(logits < min_logit,
166
+ torch.tensor(float('-inf')).to(logits.device),
167
+ logits)
168
+
169
+ # temperature scaling
170
+ if temperature > 0.0:
171
+ logits /= temperature
172
+ probs = torch.softmax(logits, dim=-1)
173
+ idx_next = torch.multinomial(probs, num_samples=1)
174
+ else:
175
+ idx_next = torch.argmax(logits, dim=-1, keepdim=True)
176
+
177
+ if idx_next.item() == eos_token:
178
+ break
179
+
180
+ idx = torch.cat((idx, idx_next), dim=1)
181
+ model.train()
182
+ return idx
183
+
184
+ def sync_device(device):
185
+ if device.startswith('cuda'):
186
+ torch.cuda.synchronize()
187
+ elif device == 'cpu':
188
+ torch.cpu.synchronize() if hasattr(torch.cpu, 'synchronize') else None
189
+ elif device.startswith('mps'): # For Apple Silicon
190
+ torch.mps.synchronize()
191
+
192
+ def print_gpu_memory(step_name=""):
193
+ """
194
+ Print GPU memory statistics with a specified step name
195
+ """
196
+ if torch.cuda.is_available():
197
+ logger.info(f"\nGPU Memory Stats {step_name}:")
198
+ logger.info(f"GPU Memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
199
+ logger.info(f"GPU Memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
200
+ logger.info(f"Max GPU Memory allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
201
+
202
+ # Learning rate scheduler
203
+ def get_lr_lambda(current_step, warmup_steps, max_steps, max_lr):
204
+ """
205
+ Modified learning rate scheduler with:
206
+ 1. Linear warmup for first 3000 steps
207
+ 2. Cosine decay from 3000 to 60000 steps
208
+ 3. Minimum learning rate of 1.5e-5 (5% of max_lr)
209
+ """
210
+ min_lr = max_lr * 0.05 # Minimum learning rate (5% of max_lr)
211
+
212
+ if current_step < warmup_steps:
213
+ # Linear warmup from 0 to max_lr
214
+ return float(current_step) / float(max(1, warmup_steps))
215
+ else:
216
+ # Cosine decay from max_lr to min_lr, returned as a multiplier of max_lr for LambdaLR (progress clamped at 1.0)
217
+ progress = min(float(current_step - warmup_steps) / float(max(1, max_steps - warmup_steps)), 1.0)
218
+ return (min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))) / max_lr
219
+
220
+
221
+ def train_model(config, model, train_loader, test_loader, optimizer, device, num_epochs, eval_freq, eval_iter, start_context="Jack Gisburn rather a cheap genius- ", tokenizer=None):
222
+ total_loss = 0
223
+ tokens_seen, global_step = 0, -1
224
+
225
+ # Adjusted gradient accumulation setup for batch size 8
226
+ actual_batch_size = config['tokens']['micro_batch_size'] # Now 8
227
+ effective_batch_size_multiplier = 1 # Adjusted for batch size 8
228
+ target_batch_size = effective_batch_size_multiplier * config['tokens']['micro_batch_size']
229
+ gradient_accumulation_steps = target_batch_size // actual_batch_size
230
+
231
+ # Learning rate parameters adjusted for batch size 8
232
+ max_lr = 3e-4 # Keep the same max learning rate
233
+ warmup_steps = 3000 # Keep warmup steps
234
+ max_steps = 60000 # Keep max steps
235
+ min_lr = max_lr * 0.05 # Keep minimum LR at 5% of max
236
+
237
+ # Create LambdaLR scheduler with the improved lambda function
238
+ lr_lambda = lambda step: get_lr_lambda(step, warmup_steps, max_steps, max_lr)
239
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
240
+
241
+ logger.info(f"Training with learning rate schedule:")
242
+ logger.info(f"Max LR: {max_lr}")
243
+ logger.info(f"Warmup Steps: {warmup_steps}")
244
+ logger.info(f"Max Steps: {max_steps}")
245
+ logger.info(f"Min LR: {max_lr * 0.05}")
246
+ logger.info(f"Gradient Accumulation Steps: {gradient_accumulation_steps}")
247
+ logger.info(f"Effective Batch Size: {actual_batch_size * gradient_accumulation_steps}")
248
+
249
+ print_gpu_memory("at start of training")
250
+
251
+ # Add these near the start of training loop
252
+ torch.cuda.empty_cache()
253
+ torch.backends.cudnn.benchmark = True
254
+ for epoch in range(num_epochs):
255
+ model.train()
256
+ optimizer.zero_grad() # Zero gradients at start of epoch
257
+
258
+ for batch_idx, batch in enumerate(train_loader):
259
+ input_batch = batch['input_ids'].to(device)
260
+ target_batch = batch['labels'].to(device)
261
+
262
+ # Forward pass
263
+ with torch.autocast(device_type=device, dtype=torch.bfloat16):
264
+ logits, original_loss = model(input_batch, target_batch)
265
+
266
+ # Scale loss for gradient accumulation
267
+ scaled_loss = original_loss / gradient_accumulation_steps
268
+ scaled_loss.backward()
269
+
270
+ # Add the original loss to total_loss for logging
271
+ total_loss += original_loss.item() # Don't multiply back up
272
+ tokens_seen += input_batch.numel()
273
+
274
+ # Calculate running average loss
275
+ total_batches = batch_idx + 1
276
+ avg_loss = total_loss / total_batches
277
+ if batch_idx % 25 == 0:
278
+ logger.info(f"Batch {batch_idx + 1}, Running Avg Loss: {avg_loss:.5f}")
279
+ # Only update weights after accumulating gradients
280
+ if (batch_idx + 1) % gradient_accumulation_steps == 0:
281
+ # Gradient clipping
282
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
283
+
284
+ optimizer.step()
285
+ scheduler.step() # Update learning rate
286
+ optimizer.zero_grad()
287
+ global_step += 1
288
+
289
+ # Evaluation block
290
+ if global_step % eval_freq == 0 and global_step > 0:
291
+ # Use total batches processed instead of global_step
292
+ current_lr = scheduler.get_last_lr()[0]
293
+ optimizer_lr = optimizer.param_groups[0]['lr']
294
+
295
+ print_gpu_memory(f"at step {global_step}")
296
+ logger.info(f"learning rate: {current_lr:.8f}")
297
+ logger.info(f"Ep {epoch+1} (Step {global_step:06d}): "
298
+ f"Avg loss {avg_loss:.3f} | {tokens_seen} tokens seen")
299
+ logger.info(f"optimizer lr: {optimizer_lr:.8f}")
300
+ logger.info(f"scheduler lr: {current_lr:.8f}")
301
+
302
+ # Generate sample text
303
+ start_context_list = ["In today's ever-evolving world, technology has become an integral part of our lives","Once upon a time, there was a friendly agency called Gaudette Insurance Agency, Inc. They help","A couple of years ago, I was working as an extra on the set of a low-budget British film.","Introduction: The Art of Crafting Vegan Sandwich Delights Sandwiches occupy a unique space in","Meet Chris, a superhero of supplies! Just like how Batman protects Gotham City","Identity formation is a complex and multifaceted process that involves the development of", "With the development of science and technology, computer has become more and more ","Just as there are many variants and forms of electronic malware and Internet-based ","Correctly identifying what is causing a problem is the most important step in pest control.","Lobster, California spiny The California Spiny Lobster fishery is a small but locally ","Bees are vital for pollination. You can buy leafcutter bee houses to attract ","Feeling Alone Together: Exploring Alienation and Isolation in Literature", "Imagine if someone got their hands on dangerous weapons","Once upon a time, in a colorful town called Popville, ","he bell above the door jangled as Sarah walked into her family's hardware store"]
304
+ # Randomly select a prompt from the list
305
+ random_prompt = np.random.choice(start_context_list)
306
+ logger.info(f"Selected prompt: {random_prompt}")
307
+ logger.info(f"+++"*30)
308
+ encoded_text = tokenizer.encode(random_prompt, return_tensors="pt")
309
+ random_topk = np.random.randint(1, 10)
310
+ logger.info(f"random_topk: {random_topk}")
311
+ random_temperature = np.random.uniform(0.7, 0.9)
312
+ logger.info(f"random_temperature: {random_temperature}")
313
+ logger.info(f"global step {global_step} , batch_idx {batch_idx} => generating text")
314
+ generated_text = generate(model,
315
+ idx=encoded_text,
316
+ max_new_tokens=256,
317
+ context_length=256,
318
+ temperature=random_temperature,
319
+ top_k=random_topk,
320
+ eos_token=tokenizer.eos_token_id,
321
+ device=device)
322
+ logger.info(f"+++"*30)
323
+ logger.info(tokenizer.decode(generated_text.squeeze(0)))
324
+ logger.info(f"+++"*30)
325
+
326
+ # Save checkpoint
327
+ model_file_name = f"model_{global_step}_steps_avg_loss_{avg_loss:.5f}_optimizer_lr_{optimizer_lr:.8f}.pth"
328
+ torch.save({
329
+ 'step': global_step,
330
+ 'model_state_dict': model.state_dict(),
331
+ 'optimizer_state_dict': optimizer.state_dict(),
332
+ 'scheduler_state_dict': scheduler.state_dict(),
333
+ 'loss': avg_loss,
334
+ }, model_file_name)
335
+
336
+ s3_path = upload_file_to_s3(model_file_name, config['model']['model_config']['s3_bucket'],
337
+ config['model']['model_config']['s3_checkpoint_folder'])
338
+ logger.info(f"Model saved to S3: {s3_path}")
339
+
340
+ log_path = upload_file_to_s3(config['model']['model_config']['s3_log_file_name'], config['model']['model_config']['s3_bucket'],
341
+ config['model']['model_config']['s3_log_folder'])
342
+ logger.info(f"Log saved to S3: {log_path}")
343
+
344
+ if batch_idx % 100 == 0:
345
+ logger.info(f"Batch {batch_idx} finished")
346
+ logger.info(f"+++"*30)
347
+
348
+ logger.info("Training complete")
349
+
350
+ if __name__ == "__main__":
351
+ config = yaml.load(open("config_smollm2_135M.yaml", "r"), Loader=yaml.FullLoader)
352
+ logger.info(config)
353
+
354
+ # Set memory efficient settings
355
+ torch.set_float32_matmul_precision('high')
356
+ torch.backends.cudnn.benchmark = True
357
+ torch.backends.cuda.matmul.allow_tf32 = True
358
+
359
+ # Empty cache before model creation
360
+ torch.cuda.empty_cache()
361
+ import os
362
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64'
363
+
364
+ model = DeepSeekV3Model(config['model'])
365
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
366
+
367
+ # Enable gradient checkpointing for memory efficiency
368
+ # model.gradient_checkpointing_enable()
369
+
370
+ model.to(device)
371
+ #model = torch.compile(model)
372
+ logger.info(model)
373
+ logger.info("++"*30)
374
+ total_params = sum(p.numel() for p in model.parameters())
375
+ logger.info(f"Total parameters: {total_params}")
376
+
377
+ optimizer = torch.optim.AdamW(
378
+ model.parameters(),
379
+ lr=3e-4,
380
+ weight_decay=0.15,
381
+ betas=(0.9, 0.95)
382
+ )
383
+
384
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
385
+ tokenizer.pad_token = tokenizer.eos_token
386
+ vocab_size = tokenizer.vocab_size
387
+
388
+ # Adjusted batch size to 8
389
+ train_loader = load_cosmopedia_dataset(
390
+ batch_size=8, # Changed from 4 to 8
391
+ seq_length=512, # Kept at 512
392
+ tokenizer=tokenizer
393
+ )
394
+
395
+ import time
396
+ t1 = time.time()
397
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
398
+
399
+ # Set environment variable for memory allocation
400
+ import os
401
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
402
+
403
+ train_model(
404
+ config,
405
+ model,
406
+ train_loader,
407
+ train_loader,
408
+ optimizer=optimizer,
409
+ device=device,
410
+ num_epochs=1,
411
+ eval_freq=2500, # Evaluate, generate a sample and checkpoint every 2500 optimizer steps
412
+ eval_iter=2500,
413
+ start_context="Once Upon a Time far far away in a galaxy",
414
+ tokenizer=tokenizer
415
+ )
416
+ t2 = time.time()
417
+ logger.info(f"Time taken for training: {t2 - t1:.2f} seconds")
utils.py ADDED
@@ -0,0 +1,182 @@
1
+ import boto3
2
+ from boto3.s3.transfer import TransferConfig
3
+ from tqdm import tqdm
4
+ import os
5
+
6
+ def upload_file_to_s3(file_path, bucket_name, s3_prefix):
7
+
8
+
9
+ class ProgressPercentage(object):
10
+ def __init__(self, filename):
11
+ self._filename = filename
12
+ self._size = float(os.path.getsize(filename))
13
+ self._seen_so_far = 0
14
+ self._pbar = tqdm(total=self._size, unit='B', unit_scale=True, desc=f"Uploading {os.path.basename(filename)}")
15
+
16
+ def __call__(self, bytes_amount):
17
+ self._seen_so_far += bytes_amount
18
+ self._pbar.update(bytes_amount)
19
+
20
+ s3_client = boto3.client('s3')
21
+ file_name = os.path.basename(file_path)
22
+ s3_path = f"{s3_prefix}/{file_name}"
23
+
24
+ # Configure multipart upload
25
+ config = TransferConfig(
26
+ multipart_threshold=1024 * 25, # 25MB
27
+ max_concurrency=10,
28
+ multipart_chunksize=1024 * 25, # 25MB
29
+ use_threads=True
30
+ )
31
+
32
+ try:
33
+ s3_client.upload_file(
34
+ file_path,
35
+ bucket_name,
36
+ s3_path,
37
+ Config=config,
38
+ Callback=ProgressPercentage(file_path)
39
+ )
40
+ return f"s3://{bucket_name}/{s3_path}"
41
+ except Exception as e:
42
+ print(f"Failed to upload {file_path} to S3: {str(e)}")
43
+ return None
44
+
45
+ max_lr = 1e-3
46
+ warmup_steps = 10
47
+ max_steps = 25000
48
+ import math
49
+ def get_lr_lambda(current_step, warmup_steps, max_steps, max_lr):
50
+ """
51
+ Learning rate scheduler with:
52
+ 1. Linear warmup
53
+ 2. Cosine decay
54
+ 3. Minimum learning rate of 10% of max_lr
55
+ """
56
+ min_lr = max_lr * 0.1 # Minimum learning rate (10% of max_lr)
57
+
58
+ if current_step < warmup_steps:
59
+ # Linear warmup
60
+ return max_lr * (current_step + 1) / warmup_steps
61
+ elif current_step > max_steps:
62
+ # After max_steps, return minimum learning rate
63
+ return min_lr
64
+ else:
65
+ # Cosine decay between warmup_steps and max_steps
66
+ decay_ratio = (current_step - warmup_steps) / (max_steps - warmup_steps)
67
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
68
+ return min_lr + coeff * (max_lr - min_lr)
69
+
70
+
71
+ def plot_lr_schedule():
72
+ """
73
+ Helper function to visualize the learning rate schedule
74
+ """
75
+ import matplotlib.pyplot as plt
76
+ steps = list(range(0, max_steps + 100))
77
+ lrs = [get_lr_lambda(step, warmup_steps, max_steps, max_lr) for step in steps]
78
+
79
+ plt.figure(figsize=(10, 5))
80
+ plt.plot(steps, lrs)
81
+ plt.title('Learning Rate Schedule')
82
+ plt.xlabel('Steps')
83
+ plt.ylabel('Learning Rate')
84
+ plt.grid(True)
85
+ plt.show()
86
+
87
+ def plot_training_loss(log_file_path, output_path=None):
88
+ """
89
+ Parse a training log file and plot the running average loss against batch steps.
90
+ Also adds a trend line to visualize the overall training progress.
91
+
92
+ Args:
93
+ log_file_path (str): Path to the training log file
94
+ output_path (str, optional): Path to save the plot as PNG. If None, displays the plot instead.
95
+ """
96
+ import re
97
+ import matplotlib.pyplot as plt
98
+ import numpy as np
99
+ from scipy.optimize import curve_fit
100
+
101
+ # Regular expression to extract batch number and loss
102
+ pattern = r"Batch (\d+), Running Avg Loss: ([0-9.]+)"
103
+
104
+ steps = []
105
+ losses = []
106
+
107
+ # Read and parse the log file
108
+ with open(log_file_path, 'r') as file:
109
+ for line in file:
110
+ match = re.search(pattern, line)
111
+ if match:
112
+ batch_num = int(match.group(1))
113
+ loss = float(match.group(2))
114
+ steps.append(batch_num)
115
+ losses.append(loss)
116
+
117
+ if not steps:
118
+ print("No loss data found in the log file.")
119
+ return
120
+
121
+ # Create the plot
122
+ plt.figure(figsize=(12, 6))
123
+ plt.plot(steps, losses, 'b-', alpha=0.5, label='Running Avg Loss')
124
+
125
+ # Add trend line (using polynomial fit)
126
+ def poly_func(x, a, b, c):
127
+ return a * x**2 + b * x + c
128
+
129
+ # Convert to numpy arrays for curve fitting
130
+ x_array = np.array(steps)
131
+ y_array = np.array(losses)
132
+
133
+ # Fit the curve
134
+ try:
135
+ popt, _ = curve_fit(poly_func, x_array, y_array)
136
+ x_line = np.linspace(min(steps), max(steps), 1000)
137
+ y_line = poly_func(x_line, *popt)
138
+ plt.plot(x_line, y_line, 'r-', label='Trend Line')
139
+ except Exception as e:
140
+ print(f"Could not fit trend line: {e}")
141
+ # Fallback to simple moving average for trend
142
+ window_size = min(len(steps) // 10, 100) if len(steps) > 100 else len(steps) // 2
143
+ if window_size > 0:
144
+ moving_avg = np.convolve(y_array, np.ones(window_size)/window_size, mode='valid')
145
+ plt.plot(steps[window_size-1:], moving_avg, 'r-', label='Moving Average Trend')
146
+
147
+ # Add labels and title
148
+ plt.xlabel('Batch Number')
149
+ plt.ylabel('Running Average Loss')
150
+ plt.title('Training Loss Over Time')
151
+ plt.grid(True)
152
+ plt.legend()
153
+
154
+ # Add min and max loss annotations
155
+ min_loss = min(losses)
156
+ min_idx = losses.index(min_loss)
157
+ max_loss = max(losses)
158
+ max_idx = losses.index(max_loss)
159
+
160
+ plt.annotate(f'Min: {min_loss:.5f}',
161
+ xy=(steps[min_idx], min_loss),
162
+ xytext=(steps[min_idx], min_loss*1.05),
163
+ arrowprops=dict(facecolor='green', shrink=0.05),
164
+ fontsize=10)
165
+
166
+ plt.annotate(f'Max: {max_loss:.5f}',
167
+ xy=(steps[max_idx], max_loss),
168
+ xytext=(steps[max_idx], max_loss*0.95),
169
+ arrowprops=dict(facecolor='red', shrink=0.05),
170
+ fontsize=10)
171
+
172
+ # Save or show the plot
173
+ plt.tight_layout()
174
+ if output_path:
175
+ plt.savefig(output_path, dpi=300, bbox_inches='tight')
176
+ print(f"Plot saved to {output_path}")
177
+ else:
178
+ plt.show()
179
+
180
+ if __name__ == "__main__":
181
+ # plot_lr_schedule()
182
+ plot_training_loss("training.log", "train_loss.png")