Spaces:
Sleeping
Sleeping
gradio app
Browse files- app.py +177 -0
- config_smollm2_135M.yaml +108 -0
- deepseek_v3.py +459 -0
- requirements.txt +14 -0
- train.py +417 -0
- utils.py +182 -0
app.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer
|
4 |
+
import yaml
|
5 |
+
from deepseek_v3 import DeepSeekV3Model
|
6 |
+
import os
|
7 |
+
|
8 |
+
|
9 |
+
def generate_helper(model, idx, max_new_tokens, context_length, temperature=1.0, top_k=None, eos_token=None, device=None):
|
10 |
+
|
11 |
+
model = model.to(device)
|
12 |
+
idx = idx.to(device)
|
13 |
+
model.eval()
|
14 |
+
for _ in range(max_new_tokens):
|
15 |
+
idx_cond = idx[:, -context_length:]
|
16 |
+
with torch.no_grad():
|
17 |
+
logits, _ = model(idx_cond) # Unpack both logits and loss (ignore loss)
|
18 |
+
logits = logits.view(idx_cond.shape[0], -1, model.config['vocab_size']) # Reshape to [batch, seq, vocab]
|
19 |
+
|
20 |
+
# Get the logits for the last token only
|
21 |
+
logits = logits[:, -1, :] # Shape: [batch_size, vocab_size]
|
22 |
+
|
23 |
+
if top_k is not None:
|
24 |
+
# top k sampling
|
25 |
+
top_logits, top_pos = torch.topk(logits, top_k)
|
26 |
+
min_logit = top_logits[:, -1].unsqueeze(-1)
|
27 |
+
logits = torch.where(logits < min_logit,
|
28 |
+
torch.tensor(float('-inf')).to(logits.device),
|
29 |
+
logits)
|
30 |
+
|
31 |
+
# temperature scaling
|
32 |
+
if temperature > 0.0:
|
33 |
+
logits /= temperature
|
34 |
+
probs = torch.softmax(logits, dim=-1)
|
35 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
36 |
+
else:
|
37 |
+
idx_next = torch.argmax(logits, dim=-1, keepdim=True)
|
38 |
+
|
39 |
+
if idx_next.item() == eos_token:
|
40 |
+
break
|
41 |
+
|
42 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
43 |
+
model.train()
|
44 |
+
return idx
|
45 |
+
|
46 |
+
def get_config(config_path):
|
47 |
+
config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader)
|
48 |
+
return config
|
49 |
+
|
50 |
+
def extract_and_save_weights(config_path, checkpoint_path, weights_path, device):
|
51 |
+
"""Extract model weights from checkpoint and save as a separate .pt file"""
|
52 |
+
print(f"Extracting weights from checkpoint: {checkpoint_path}")
|
53 |
+
config = get_config(config_path)
|
54 |
+
model = DeepSeekV3Model(config['model'])
|
55 |
+
|
56 |
+
# Load checkpoint
|
57 |
+
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
|
58 |
+
state_dict = checkpoint['model_state_dict']
|
59 |
+
state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
|
60 |
+
|
61 |
+
# Save just the model weights
|
62 |
+
torch.save(state_dict, weights_path)
|
63 |
+
print(f"Model weights saved to: {weights_path}")
|
64 |
+
return state_dict
|
65 |
+
|
66 |
+
def load_weights(config, weights_path, device):
|
67 |
+
"""Load model from weights file"""
|
68 |
+
print(f"Loading model from weights: {weights_path}")
|
69 |
+
model = DeepSeekV3Model(config['model'])
|
70 |
+
state_dict = torch.load(weights_path, map_location=torch.device(device))
|
71 |
+
model.load_state_dict(state_dict)
|
72 |
+
return model
|
73 |
+
|
74 |
+
def get_tokenizer(config):
|
75 |
+
tokenizer_path = config['tokenizer']['tokenizer_name_or_path']
|
76 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
77 |
+
tokenizer.pad_token = tokenizer.eos_token
|
78 |
+
vocab_size = tokenizer.vocab_size
|
79 |
+
return tokenizer, vocab_size
|
80 |
+
|
81 |
+
def generate_text(model, tokenizer, input_text, max_new_tokens, context_length, temperature, top_k, eos_token, device):
|
82 |
+
encoded_text = tokenizer.encode(input_text, return_tensors="pt").to(device)
|
83 |
+
generated_text = generate_helper(model,
|
84 |
+
idx=encoded_text,
|
85 |
+
max_new_tokens=max_new_tokens,
|
86 |
+
context_length=context_length,
|
87 |
+
temperature=temperature,
|
88 |
+
top_k=top_k,
|
89 |
+
eos_token=eos_token,
|
90 |
+
device=device)
|
91 |
+
return tokenizer.decode(generated_text.squeeze(0))
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
# Initialize model and tokenizer
|
96 |
+
def initialize_model():
|
97 |
+
config_path = "config_smollm2_135M.yaml"
|
98 |
+
# Use HF Hub or another external storage instead of local path
|
99 |
+
model_id = "crpatel/DeepSeek-V3-SmolLm2" # Replace with your actual model ID
|
100 |
+
weights_path = "model.pt"
|
101 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
102 |
+
|
103 |
+
# Load configuration
|
104 |
+
config = get_config(config_path)
|
105 |
+
|
106 |
+
# Check if weights exist locally, otherwise download from HF Hub
|
107 |
+
if not os.path.exists(weights_path):
|
108 |
+
try:
|
109 |
+
from huggingface_hub import hf_hub_download
|
110 |
+
print(f"Downloading model weights from Hugging Face Hub: {model_id}")
|
111 |
+
weights_path = hf_hub_download(
|
112 |
+
repo_id=model_id,
|
113 |
+
filename="model.pt"
|
114 |
+
)
|
115 |
+
except Exception as e:
|
116 |
+
print(f"Error downloading weights: {e}")
|
117 |
+
print("Falling back to local checkpoint extraction if available")
|
118 |
+
checkpoint_path = "checkpoints/model_100000_step_avg_loss_4.61663.pth"
|
119 |
+
if os.path.exists(checkpoint_path):
|
120 |
+
extract_and_save_weights(config_path, checkpoint_path, weights_path, device)
|
121 |
+
else:
|
122 |
+
raise FileNotFoundError(f"Neither weights file nor checkpoint found. Please upload model to HF Hub first.")
|
123 |
+
|
124 |
+
# Load model from weights
|
125 |
+
model = load_weights(config, weights_path, device)
|
126 |
+
model.to(device)
|
127 |
+
model.eval()
|
128 |
+
|
129 |
+
# Load tokenizer
|
130 |
+
tokenizer, vocab_size = get_tokenizer(config)
|
131 |
+
|
132 |
+
return model, tokenizer, device
|
133 |
+
|
134 |
+
def generate_response(prompt, max_new_tokens):
|
135 |
+
generated_text = generate_text(
|
136 |
+
model=model,
|
137 |
+
tokenizer=tokenizer,
|
138 |
+
input_text=prompt,
|
139 |
+
max_new_tokens=max_new_tokens,
|
140 |
+
context_length=256,
|
141 |
+
temperature=0.9,
|
142 |
+
top_k=2,
|
143 |
+
eos_token=tokenizer.eos_token_id,
|
144 |
+
device=device
|
145 |
+
)
|
146 |
+
return generated_text
|
147 |
+
|
148 |
+
# Initialize global variables
|
149 |
+
model, tokenizer, device = initialize_model()
|
150 |
+
|
151 |
+
# Create Gradio interface
|
152 |
+
iface = gr.Interface(
|
153 |
+
fn=generate_response,
|
154 |
+
inputs=[
|
155 |
+
gr.Textbox(
|
156 |
+
lines=3,
|
157 |
+
placeholder="Enter your prompt here...",
|
158 |
+
label="Input Prompt"
|
159 |
+
),
|
160 |
+
gr.Slider(
|
161 |
+
minimum=50,
|
162 |
+
maximum=256,
|
163 |
+
value=100,
|
164 |
+
step=10,
|
165 |
+
label="Max New Tokens"
|
166 |
+
)
|
167 |
+
],
|
168 |
+
outputs=gr.Textbox(
|
169 |
+
lines=5,
|
170 |
+
label="Generated Text"
|
171 |
+
),
|
172 |
+
title="DeepSeek-V3 Text Generator",
|
173 |
+
description="Enter a prompt and adjust the maximum number of tokens to generate text with DeepSeek-V3 SmolLM2 model."
|
174 |
+
)
|
175 |
+
|
176 |
+
if __name__ == "__main__":
|
177 |
+
iface.launch()
|
config_smollm2_135M.yaml
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
checkpoints:
|
2 |
+
checkpoint_interval: 2000
|
3 |
+
checkpoints_path: checkpoints
|
4 |
+
checkpoints_path_is_shared_file_system: false
|
5 |
+
resume_checkpoint_path: null
|
6 |
+
save_final_state: false
|
7 |
+
save_initial_state: false
|
8 |
+
data_stages:
|
9 |
+
- data:
|
10 |
+
dataset:
|
11 |
+
dataset_folder:
|
12 |
+
- datasets/smollm2-corpus
|
13 |
+
dataset_weights:
|
14 |
+
- 1.0
|
15 |
+
num_loading_workers: 0
|
16 |
+
seed: 8
|
17 |
+
name: stable phase
|
18 |
+
start_training_step: 1
|
19 |
+
general:
|
20 |
+
benchmark_csv_path: null
|
21 |
+
consumed_train_samples: null
|
22 |
+
ignore_sanity_checks: true
|
23 |
+
project: smollm2
|
24 |
+
run: smollm2-135M
|
25 |
+
seed: 8
|
26 |
+
step: null
|
27 |
+
logging:
|
28 |
+
iteration_step_info_interval: 1
|
29 |
+
log_level: info
|
30 |
+
log_level_replica: info
|
31 |
+
model:
|
32 |
+
ddp_bucket_cap_mb: 25
|
33 |
+
dtype: bfloat16
|
34 |
+
init_method:
|
35 |
+
std: 0.041666666666666664
|
36 |
+
make_vocab_size_divisible_by: 1
|
37 |
+
model_config:
|
38 |
+
bos_token_id: 0
|
39 |
+
eos_token_id: 0
|
40 |
+
hidden_act: silu
|
41 |
+
hidden_size: 576
|
42 |
+
initializer_range: 0.041666666666666664
|
43 |
+
intermediate_size: 1536
|
44 |
+
is_llama_config: true
|
45 |
+
max_position_embeddings: 2048
|
46 |
+
num_attention_heads: 9
|
47 |
+
num_hidden_layers: 30
|
48 |
+
num_key_value_heads: 3
|
49 |
+
pad_token_id: null
|
50 |
+
pretraining_tp: 1
|
51 |
+
rms_norm_eps: 1.0e-05
|
52 |
+
rope_interleaved: false
|
53 |
+
rope_scaling: null
|
54 |
+
rope_theta: 10000.0
|
55 |
+
tie_word_embeddings: true
|
56 |
+
use_cache: true
|
57 |
+
vocab_size: 49152
|
58 |
+
s3_bucket: deepseek-v3-train-mar-2025
|
59 |
+
s3_checkpoint_folder: checkpoints
|
60 |
+
s3_log_folder: logs
|
61 |
+
s3_log_file_name: training.log
|
62 |
+
# deepseek
|
63 |
+
compression_ratio: 4
|
64 |
+
num_experts: 4
|
65 |
+
num_shared_experts: 1
|
66 |
+
top_k: 2
|
67 |
+
optimizer:
|
68 |
+
accumulate_grad_in_fp32: true
|
69 |
+
clip_grad: 1.0
|
70 |
+
learning_rate_scheduler:
|
71 |
+
learning_rate: 0.003
|
72 |
+
lr_decay_starting_step: 1600000
|
73 |
+
lr_decay_steps: 400000
|
74 |
+
lr_decay_style: linear
|
75 |
+
lr_warmup_steps: 2000
|
76 |
+
lr_warmup_style: linear
|
77 |
+
min_decay_lr: 0
|
78 |
+
optimizer_factory:
|
79 |
+
adam_beta1: 0.9
|
80 |
+
adam_beta2: 0.95
|
81 |
+
adam_eps: 1.0e-08
|
82 |
+
name: adamW
|
83 |
+
torch_adam_is_fused: true
|
84 |
+
weight_decay: 0.01
|
85 |
+
zero_stage: 0
|
86 |
+
parallelism:
|
87 |
+
dp: 64
|
88 |
+
expert_parallel_size: 1
|
89 |
+
pp: 1
|
90 |
+
pp_engine: 1f1b
|
91 |
+
recompute_layer: false
|
92 |
+
tp: 1
|
93 |
+
tp_linear_async_communication: true
|
94 |
+
tp_mode: REDUCE_SCATTER
|
95 |
+
tp_recompute_allgather: true
|
96 |
+
profiler: null
|
97 |
+
tokenizer:
|
98 |
+
tokenizer_max_length: null
|
99 |
+
tokenizer_name_or_path: HuggingFaceTB/cosmo2-tokenizer
|
100 |
+
tokenizer_revision: null
|
101 |
+
tokens:
|
102 |
+
batch_accumulation_per_replica: 1
|
103 |
+
limit_test_batches: 0
|
104 |
+
limit_val_batches: 0
|
105 |
+
micro_batch_size: 8
|
106 |
+
sequence_length: 512
|
107 |
+
train_steps: 2000000
|
108 |
+
val_check_interval: 1000
|
deepseek_v3.py
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn import SiLU
|
5 |
+
import yaml
|
6 |
+
|
7 |
+
|
8 |
+
def _init_weights(module, std=0.041666666666666664):
|
9 |
+
if isinstance(module, nn.Linear):
|
10 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
11 |
+
elif isinstance(module, nn.Embedding):
|
12 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
13 |
+
|
14 |
+
class RotaryPositionalEmbedding(nn.Module):
|
15 |
+
"""
|
16 |
+
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L240
|
17 |
+
Rotary Positional Embedding (RoPE) for transformers Implemntation derived from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
18 |
+
"""
|
19 |
+
def __init__(self, dim: int, theta: float = 10000.0):
|
20 |
+
super().__init__()
|
21 |
+
self.dim = dim
|
22 |
+
self.theta = theta
|
23 |
+
|
24 |
+
def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
|
25 |
+
"""
|
26 |
+
Apply rotary positional embedding to the input tensor.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
x (torch.Tensor): Input tensor of shape [B, T, H, D] or [B, T, D]
|
30 |
+
seq_len (int): Sequence length.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
torch.Tensor: Output tensor with rotary positional embeddings applied.
|
34 |
+
"""
|
35 |
+
# Handle different input shapes
|
36 |
+
if len(x.shape) == 3:
|
37 |
+
B, T, D = x.shape
|
38 |
+
is_4d = False
|
39 |
+
else:
|
40 |
+
B, T, H, D = x.shape
|
41 |
+
is_4d = True
|
42 |
+
|
43 |
+
# For 3D tensors, we need to ensure D is even
|
44 |
+
if not is_4d and D % 2 != 0:
|
45 |
+
raise ValueError(f"Feature dimension {D} must be divisible by 2 for RoPE")
|
46 |
+
|
47 |
+
# Generate position indices
|
48 |
+
position = torch.arange(T, dtype=torch.float32, device=x.device).unsqueeze(-1)
|
49 |
+
|
50 |
+
# Generate frequencies
|
51 |
+
if is_4d:
|
52 |
+
# For 4D tensors, use the head dimension
|
53 |
+
freqs = torch.exp(
|
54 |
+
torch.arange(0, D, 2, dtype=torch.float32, device=x.device) *
|
55 |
+
-(torch.log(torch.tensor(self.theta)) / D)
|
56 |
+
)
|
57 |
+
else:
|
58 |
+
# For 3D tensors, use the full dimension
|
59 |
+
freqs = torch.exp(
|
60 |
+
torch.arange(0, D, 2, dtype=torch.float32, device=x.device) *
|
61 |
+
-(torch.log(torch.tensor(self.theta)) / D)
|
62 |
+
)
|
63 |
+
|
64 |
+
# Compute sinusoids
|
65 |
+
sinusoid = position * freqs
|
66 |
+
sin = torch.sin(sinusoid)
|
67 |
+
cos = torch.cos(sinusoid)
|
68 |
+
|
69 |
+
# Reshape sin and cos to match the input tensor's shape
|
70 |
+
if is_4d:
|
71 |
+
sin = sin.unsqueeze(0).unsqueeze(2) # Shape: (1, T, 1, D // 2)
|
72 |
+
cos = cos.unsqueeze(0).unsqueeze(2) # Shape: (1, T, 1, D // 2)
|
73 |
+
else:
|
74 |
+
sin = sin.unsqueeze(0) # Shape: (1, T, D // 2)
|
75 |
+
cos = cos.unsqueeze(0) # Shape: (1, T, D // 2)
|
76 |
+
|
77 |
+
# Apply rotary embeddings
|
78 |
+
x_rotated = x.clone()
|
79 |
+
|
80 |
+
if is_4d:
|
81 |
+
x_rotated[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin
|
82 |
+
x_rotated[..., 1::2] = x[..., 1::2] * cos + x[..., 0::2] * sin
|
83 |
+
else:
|
84 |
+
x_rotated[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin
|
85 |
+
x_rotated[..., 1::2] = x[..., 1::2] * cos + x[..., 0::2] * sin
|
86 |
+
|
87 |
+
return x_rotated
|
88 |
+
|
89 |
+
class MultiHeadLatentAttention(nn.Module):
|
90 |
+
def __init__(self, config):
|
91 |
+
super().__init__()
|
92 |
+
self.config = config
|
93 |
+
self.num_attention_heads = self.config['num_attention_heads']
|
94 |
+
self.hidden_size = self.config['hidden_size']
|
95 |
+
# Ensure the hidden size is divisible by the number of attention heads
|
96 |
+
if self.hidden_size % self.num_attention_heads != 0:
|
97 |
+
raise ValueError(
|
98 |
+
f"hidden_size ({self.hidden_size}) must be divisible by num_attention_heads ({self.num_attention_heads})"
|
99 |
+
)
|
100 |
+
self.head_dim = self.hidden_size // self.num_attention_heads
|
101 |
+
self.latent_dim = self.hidden_size // self.config['compression_ratio']
|
102 |
+
|
103 |
+
# Matrix is decomposed into D and U matrix
|
104 |
+
# Compression KV Projection Matrix
|
105 |
+
self.kv_proj_D = nn.Linear(self.hidden_size, self.latent_dim, bias=False)
|
106 |
+
# Compression Q Projection Matrix
|
107 |
+
self.q_proj_D = nn.Linear(self.hidden_size, self.latent_dim, bias=False)
|
108 |
+
|
109 |
+
# UnCompression k projection matrix
|
110 |
+
self.k_proj_U = nn.Linear(self.latent_dim, self.hidden_size//2, bias=False)
|
111 |
+
# UnCompression v projection matrix
|
112 |
+
self.v_proj_U = nn.Linear(self.latent_dim, self.hidden_size, bias=False)
|
113 |
+
# UnCompression Q projection matrix
|
114 |
+
self.q_proj_U = nn.Linear(self.latent_dim, self.hidden_size//2, bias=False)
|
115 |
+
|
116 |
+
# Rope Key Components, K is built from X and Q is build from q_proj_D
|
117 |
+
self.rope_k = nn.Linear(self.hidden_size, self.hidden_size//2, bias=False)
|
118 |
+
self.rope_q = nn.Linear(self.latent_dim, self.hidden_size//2, bias=False)
|
119 |
+
# output projection matrix
|
120 |
+
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
|
121 |
+
|
122 |
+
self.rotary_emb = RotaryPositionalEmbedding(self.hidden_size//2, self.config['rope_theta'])
|
123 |
+
|
124 |
+
def forward(self, x, attn_mask=None):
|
125 |
+
B, T, C = x.size() # Batch Size, Sequence Length, Hidden Size
|
126 |
+
# Compression KV Projection Matrix
|
127 |
+
kv_d = self.kv_proj_D(x) # [B, T, Latent Dim]
|
128 |
+
# Compression Q Projection Matrix
|
129 |
+
q_d = self.q_proj_D(x) # [B, T, Latent Dim]
|
130 |
+
# Uncompress KV & Q Projection Matrix
|
131 |
+
k_proj_2 = self.k_proj_U(kv_d) # [B, T, Hidden Size//2]
|
132 |
+
q_proj_2 = self.q_proj_U(q_d) # [B, T, Hidden Size//2]
|
133 |
+
v = self.v_proj_U(kv_d) # [B, T, Hidden Size]
|
134 |
+
|
135 |
+
# Rope components
|
136 |
+
k_rope_2 = self.rope_k(x) # [B, T, Hidden Size//2]
|
137 |
+
q_rope_2 = self.rope_q(q_d) # [B, T, Hidden Size//2]
|
138 |
+
|
139 |
+
# Apply ROPE to the rope components
|
140 |
+
k_rope_2 = self.rotary_emb(k_rope_2, T) # [B, T, Hidden Size//2]
|
141 |
+
q_rope_2 = self.rotary_emb(q_rope_2, T) # [B, T, Hidden Size//2]
|
142 |
+
|
143 |
+
# Reshape Components for Multi-Head Attention
|
144 |
+
k_proj_2 = k_proj_2.view(B, T, self.num_attention_heads, self.head_dim//2)
|
145 |
+
k_rope_2 = k_rope_2.view(B, T, self.num_attention_heads, self.head_dim//2)
|
146 |
+
q_proj_2 = q_proj_2.view(B, T, self.num_attention_heads, self.head_dim//2)
|
147 |
+
q_rope_2 = q_rope_2.view(B, T, self.num_attention_heads, self.head_dim//2)
|
148 |
+
|
149 |
+
# Concatenate Components
|
150 |
+
k = torch.cat((k_proj_2, k_rope_2), dim=-1) # [B, T, H, D]
|
151 |
+
q = torch.cat((q_proj_2, q_rope_2), dim=-1) # [B, T, H, D]
|
152 |
+
v = v.view(B, T, self.num_attention_heads, self.head_dim)
|
153 |
+
|
154 |
+
# Reshape Components for Multi-Head Attention
|
155 |
+
k = k.transpose(1, 2) # [B, H, T, D]
|
156 |
+
q = q.transpose(1, 2) # [B, H, T, D]
|
157 |
+
v = v.transpose(1, 2) # [B, H, T, D]
|
158 |
+
|
159 |
+
# Apply Scaled Dot-Product Attention
|
160 |
+
attn_out = F.scaled_dot_product_attention(q, k, v,
|
161 |
+
dropout_p=0.0,
|
162 |
+
is_causal=True,
|
163 |
+
attn_mask=attn_mask)
|
164 |
+
attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, C) # [B, T, C]
|
165 |
+
return self.o_proj(attn_out) # [B, T, C]
|
166 |
+
|
167 |
+
class DeepSeekExpertLayer(nn.Module):
|
168 |
+
def __init__(self, hidden_size, intermediate_size):
|
169 |
+
super().__init__()
|
170 |
+
self.hidden_size = hidden_size
|
171 |
+
self.intermediate_size = intermediate_size
|
172 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
173 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
174 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
175 |
+
self.act_fn = SiLU()
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
179 |
+
|
180 |
+
class DeepSeekMOE(nn.Module):
|
181 |
+
"""
|
182 |
+
A Mixture of Experts (MoE) layer that routes input through a set of expert layers.
|
183 |
+
|
184 |
+
This class implements a mixture of experts mechanism where a subset of experts is selected
|
185 |
+
for each input token based on learned routing logits. The output is a combination of the
|
186 |
+
shared experts and the routed experts, allowing for efficient computation and increased
|
187 |
+
model capacity.
|
188 |
+
|
189 |
+
Attributes:
|
190 |
+
hidden_size (int): The size of the hidden layer.
|
191 |
+
intermediate_size (int): The size of the intermediate layer.
|
192 |
+
num_experts (int): Total number of experts available.
|
193 |
+
num_shared_experts (int): Number of shared experts that are used for all inputs.
|
194 |
+
top_k (int): The number of top experts to route each input to.
|
195 |
+
shared_experts (nn.ModuleList): List of shared expert layers.
|
196 |
+
routed_experts (nn.ModuleList): List of routed expert layers.
|
197 |
+
routing_fn (nn.Linear): Linear layer for computing routing logits.
|
198 |
+
routing_bias (nn.Parameter): Bias for the routing logits.
|
199 |
+
|
200 |
+
Methods:
|
201 |
+
forward(x): Forward pass through the MoE layer, routing input through selected experts.
|
202 |
+
"""
|
203 |
+
def __init__(self, hidden_size, intermediate_size, num_experts, num_shared_experts, top_k):
|
204 |
+
super().__init__()
|
205 |
+
self.hidden_size = hidden_size
|
206 |
+
self.intermediate_size = intermediate_size
|
207 |
+
self.num_experts = num_experts
|
208 |
+
self.num_shared_experts = num_shared_experts
|
209 |
+
self.top_k = top_k
|
210 |
+
self.num_routed_experts = num_experts - num_shared_experts
|
211 |
+
self.shared_experts = nn.ModuleList(
|
212 |
+
[DeepSeekExpertLayer(self.hidden_size, self.intermediate_size) for _ in range(self.num_shared_experts)]
|
213 |
+
)
|
214 |
+
self.routed_experts = nn.ModuleList(
|
215 |
+
[DeepSeekExpertLayer(self.hidden_size, self.intermediate_size) for _ in range(self.num_routed_experts)]
|
216 |
+
)
|
217 |
+
|
218 |
+
# Routing Function
|
219 |
+
self.routing_fn = nn.Linear(self.hidden_size, self.num_routed_experts, bias=False)
|
220 |
+
self.routing_bias = nn.Parameter(torch.zeros(self.num_routed_experts))
|
221 |
+
def forward(self, x):
|
222 |
+
B, T, C = x.size()
|
223 |
+
shared_out = sum(expert(x) for expert in self.shared_experts)
|
224 |
+
if self.num_shared_experts>1:
|
225 |
+
shared_out = shared_out/self.num_shared_experts # normalize the shared experts
|
226 |
+
# calculate the routing function
|
227 |
+
routing_logits = self.routing_fn(x) + self.routing_bias # [B, T, num_routed_experts]
|
228 |
+
# GEt Topk Experts per token
|
229 |
+
routing_probs = torch.sigmoid(routing_logits) # [B, T, num_routed_experts]
|
230 |
+
scores, indices = torch.topk(routing_probs, self.top_k, dim=-1) # [B, T, top_k]
|
231 |
+
# normalize the top k scores
|
232 |
+
scores = scores/torch.sum(scores, dim=-1, keepdim=True)
|
233 |
+
# process the routed experts
|
234 |
+
#combined_output = torch.zeros(B, T, C, device=x.device)
|
235 |
+
combined_output = torch.zeros_like(x)
|
236 |
+
|
237 |
+
# Calculate expert load for all experts
|
238 |
+
expert_load = torch.zeros(self.num_routed_experts, device=x.device)
|
239 |
+
for i in range(self.top_k):
|
240 |
+
expert_idx = indices[:, :, i] # [B, T, top_k]
|
241 |
+
|
242 |
+
expert_scores = scores[...,i:i+1]
|
243 |
+
# process the routed experts
|
244 |
+
for j in range(self.num_routed_experts):
|
245 |
+
mask = (expert_idx == j) # [B, T, 1]
|
246 |
+
if mask.any():
|
247 |
+
# Track expert usage (load)
|
248 |
+
expert_load[j] += mask.sum().float() / (B * T * self.top_k)
|
249 |
+
# Process tokens through this expert
|
250 |
+
expert_input = x[mask] # [B, T, 1, C]
|
251 |
+
expert_output = self.routed_experts[j](expert_input)
|
252 |
+
combined_output[mask] += expert_scores[mask] * expert_output
|
253 |
+
final_output = shared_out + combined_output
|
254 |
+
router_z_loss = self.update_bias_terms(expert_load)
|
255 |
+
return final_output, router_z_loss
|
256 |
+
|
257 |
+
def update_bias_terms(self, expert_load, router_z_loss_coef=0.001):
|
258 |
+
# Balance expert routing by adjusting the bias terms
|
259 |
+
# Target load is uniform distribution across experts
|
260 |
+
target_load = 1.0 / self.num_routed_experts
|
261 |
+
|
262 |
+
# Calculate load imbalance for each expert
|
263 |
+
load_diff = expert_load - target_load
|
264 |
+
|
265 |
+
# Dynamic update rate based on the magnitude of imbalance
|
266 |
+
# Larger imbalances get larger corrections
|
267 |
+
update_rate = 0.1 * torch.abs(load_diff)
|
268 |
+
|
269 |
+
# Update the routing bias to counteract imbalance
|
270 |
+
# Decrease bias for overutilized experts, increase for underutilized
|
271 |
+
self.routing_bias.data -= update_rate * load_diff
|
272 |
+
|
273 |
+
# Calculate the router z-loss to discourage extreme routing probabilities
|
274 |
+
# This helps stabilize training without auxiliary losses
|
275 |
+
# Z-loss encourages routing probabilities to stay away from 0 and 1
|
276 |
+
router_z_loss = router_z_loss_coef * torch.mean(torch.log(torch.sum(
|
277 |
+
torch.exp(self.routing_fn.weight), dim=-1)))
|
278 |
+
|
279 |
+
return router_z_loss
|
280 |
+
|
281 |
+
def update_bias_terms_old(self, expert_load, ):
|
282 |
+
# adjust the bias terms based on the expert load
|
283 |
+
target_load = 1/self.num_experts
|
284 |
+
load_diff = expert_load - target_load
|
285 |
+
# dyanamic update the bias based on the load imbalance
|
286 |
+
update_rate = 0.1 * torch.abs(load_diff)
|
287 |
+
# dyanmic update the bias terms using update rate
|
288 |
+
self.routing_bias = self.routing_bias - update_rate * load_diff
|
289 |
+
|
290 |
+
# for i in range(self.num_routed_experts):
|
291 |
+
# if expert_load[i] < target_load:
|
292 |
+
# self.routing_bias[i] -= 1
|
293 |
+
# else:
|
294 |
+
# self.routing_bias[i] += 1
|
295 |
+
class LlamaMLP(nn.Module):
|
296 |
+
"""
|
297 |
+
(mlp): LlamaMLP(
|
298 |
+
(moe): DeepSeekMOE(
|
299 |
+
(shared_experts): ModuleList(
|
300 |
+
(0): DeepSeekExpertLayer(
|
301 |
+
(gate_proj): Linear(in_features=576, out_features=1536, bias=False)
|
302 |
+
(up_proj): Linear(in_features=576, out_features=1536, bias=False)
|
303 |
+
(down_proj): Linear(in_features=1536, out_features=576, bias=False)
|
304 |
+
(act_fn): SiLU()
|
305 |
+
)
|
306 |
+
)
|
307 |
+
(routed_experts): ModuleList(
|
308 |
+
(0-2): 3 x DeepSeekExpertLayer(
|
309 |
+
(gate_proj): Linear(in_features=576, out_features=1536, bias=False)
|
310 |
+
(up_proj): Linear(in_features=576, out_features=1536, bias=False)
|
311 |
+
(down_proj): Linear(in_features=1536, out_features=576, bias=False)
|
312 |
+
(act_fn): SiLU()
|
313 |
+
)
|
314 |
+
)
|
315 |
+
(routing_fn): Linear(in_features=576, out_features=3, bias=False)
|
316 |
+
)
|
317 |
+
)
|
318 |
+
"""
|
319 |
+
def __init__(self, config):
|
320 |
+
super().__init__()
|
321 |
+
self.config = config
|
322 |
+
self.moe = DeepSeekMOE(hidden_size=config['hidden_size'],
|
323 |
+
intermediate_size=config['intermediate_size'],
|
324 |
+
num_experts=config['num_experts'],
|
325 |
+
num_shared_experts= config['num_shared_experts'],
|
326 |
+
top_k=config['top_k'])
|
327 |
+
# self.gate_proj = nn.Linear(self.config['hidden_size'], self.config['intermediate_size'], bias=False)
|
328 |
+
# self.up_proj = nn.Linear(self.config['hidden_size'], self.config['intermediate_size'], bias=False)
|
329 |
+
# self.down_proj = nn.Linear(self.config['intermediate_size'], self.config['hidden_size'], bias=False)
|
330 |
+
# self.act_fn = SiLU()
|
331 |
+
def forward(self, x):
|
332 |
+
output, router_z_loss = self.moe(x)
|
333 |
+
return output, router_z_loss
|
334 |
+
# gate = self.gate_proj(x)
|
335 |
+
# up = self.up_proj(x)
|
336 |
+
# down = self.down_proj(self.act_fn(gate)*up)
|
337 |
+
# return down
|
338 |
+
|
339 |
+
class LlamaRMSNorm(nn.Module):
|
340 |
+
"""
|
341 |
+
(norm): LlamaRMSNorm((576,), eps=1e-05)
|
342 |
+
# RMSNorm Formula:
|
343 |
+
# RMS(x) = sqrt((1 / d) * sum(x_i^2 for i in range(d)))
|
344 |
+
# x_normalized = x / RMS(x)
|
345 |
+
# output = gamma * x_normalized
|
346 |
+
|
347 |
+
"""
|
348 |
+
def __init__(self, config):
|
349 |
+
super().__init__()
|
350 |
+
self.config = config
|
351 |
+
self.eps = self.config['rms_norm_eps']
|
352 |
+
self.weight = nn.Parameter(torch.ones(self.config['hidden_size']))
|
353 |
+
def forward(self, x):
|
354 |
+
rms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
|
355 |
+
return self.weight *rms * x
|
356 |
+
|
357 |
+
class LlamaDecoderLayer(nn.Module):
|
358 |
+
def __init__(self, config):
|
359 |
+
super().__init__()
|
360 |
+
self.config = config
|
361 |
+
self.self_attn = MultiHeadLatentAttention(self.config)
|
362 |
+
self.input_layernorm = LlamaRMSNorm(self.config)
|
363 |
+
self.mlp = LlamaMLP(self.config)
|
364 |
+
self.post_attention_layernorm = LlamaRMSNorm(self.config)
|
365 |
+
|
366 |
+
def forward(self, x):
|
367 |
+
residual = x
|
368 |
+
x = self.input_layernorm(x)
|
369 |
+
x = self.self_attn(x)
|
370 |
+
x = x + residual
|
371 |
+
residual = x
|
372 |
+
x = self.post_attention_layernorm(x)
|
373 |
+
x, router_z_loss = self.mlp(x)
|
374 |
+
x = x + residual
|
375 |
+
return x, router_z_loss
|
376 |
+
|
377 |
+
class DeepSeekV3Model(nn.Module):
|
378 |
+
def __init__(self, config):
|
379 |
+
super().__init__()
|
380 |
+
self.init_method = config['init_method']
|
381 |
+
self.config = config['model_config']
|
382 |
+
self.embed_tokens = nn.Embedding(self.config['vocab_size'], self.config['hidden_size'])
|
383 |
+
self.rotary_emb = RotaryPositionalEmbedding(self.config['hidden_size'], self.config['rope_theta'])
|
384 |
+
self.layers = nn.ModuleList([LlamaDecoderLayer(self.config) for _ in range(self.config['num_hidden_layers'])])
|
385 |
+
self.norm = LlamaRMSNorm(self.config)
|
386 |
+
self.lm_head = nn.Linear(self.config['hidden_size'], self.config['vocab_size'], bias=False)
|
387 |
+
|
388 |
+
if self.config['tie_word_embeddings']:
|
389 |
+
self.lm_head.weight = self.embed_tokens.weight
|
390 |
+
|
391 |
+
self.apply(lambda m: _init_weights(m, self.init_method['std']))
|
392 |
+
|
393 |
+
def forward(self, x, y=None):
|
394 |
+
x = self.embed_tokens(x)
|
395 |
+
total_router_z_loss = 0.0
|
396 |
+
for layer in self.layers:
|
397 |
+
x, router_z_loss = layer(x)
|
398 |
+
total_router_z_loss += router_z_loss
|
399 |
+
x = self.norm(x)
|
400 |
+
logits = self.lm_head(x) # B,T,V
|
401 |
+
logits = logits.view(-1, logits.size(-1)) # Shape: [B*T, V] # 20, 49152
|
402 |
+
if y is not None:
|
403 |
+
y = y.view(-1) # Shape: [B*T] # 20
|
404 |
+
ce_loss = torch.nn.functional.cross_entropy(logits, y)
|
405 |
+
# Combine CE loss with router z-loss
|
406 |
+
loss = ce_loss + total_router_z_loss
|
407 |
+
return logits, loss
|
408 |
+
else:
|
409 |
+
return logits, None
|
410 |
+
|
411 |
+
|
412 |
+
def generate(self, idx, max_new_tokens, context_length, temperature=1.0, top_k=None, eos_token=None, device=None):
|
413 |
+
model = self.to(device)
|
414 |
+
idx = idx.to(device)
|
415 |
+
model.eval()
|
416 |
+
for _ in range(max_new_tokens):
|
417 |
+
idx_cond = idx[:, -context_length:]
|
418 |
+
with torch.no_grad():
|
419 |
+
logits, _ = model(idx_cond) # Unpack both logits and loss (ignore loss)
|
420 |
+
logits = logits.view(idx_cond.shape[0], -1, model.config['vocab_size']) # Reshape to [batch, seq, vocab]
|
421 |
+
|
422 |
+
# Get the logits for the last token only
|
423 |
+
logits = logits[:, -1, :] # Shape: [batch_size, vocab_size]
|
424 |
+
|
425 |
+
if top_k is not None:
|
426 |
+
# top k sampling
|
427 |
+
top_logits, top_pos = torch.topk(logits, top_k)
|
428 |
+
min_logit = top_logits[:, -1].unsqueeze(-1)
|
429 |
+
logits = torch.where(logits < min_logit,
|
430 |
+
torch.tensor(float('-inf')).to(logits.device),
|
431 |
+
logits)
|
432 |
+
|
433 |
+
# temperature scaling
|
434 |
+
if temperature > 0.0:
|
435 |
+
logits /= temperature
|
436 |
+
probs = torch.softmax(logits, dim=-1)
|
437 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
438 |
+
else:
|
439 |
+
idx_next = torch.argmax(logits, dim=-1, keepdim=True)
|
440 |
+
|
441 |
+
if idx_next.item() == eos_token:
|
442 |
+
break
|
443 |
+
|
444 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
445 |
+
model.train()
|
446 |
+
return idx
|
447 |
+
|
448 |
+
# if __name__ == "__main__":
|
449 |
+
# torch.manual_seed(0)
|
450 |
+
# config = yaml.load(open("config_smollm2_135M.yaml", "r"), Loader=yaml.FullLoader)
|
451 |
+
# print(config.keys())
|
452 |
+
# model_config = config['model']['model_config']
|
453 |
+
# print(model_config)
|
454 |
+
# model = DeepSeekV3Model(config['model'])
|
455 |
+
# x_tokens = torch.randint(0, model_config['vocab_size'], (1, 10)) # Generate random token indices
|
456 |
+
# print(model(x_tokens).shape)
|
457 |
+
# total_params = sum(p.numel() for p in model.parameters())
|
458 |
+
# print(f"Total parameters: {total_params}") #134515008
|
459 |
+
# print(model)
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchtext
|
3 |
+
pandas
|
4 |
+
numpy==1.26.1
|
5 |
+
matplotlib
|
6 |
+
tiktoken
|
7 |
+
tensorflow>=2.15.0
|
8 |
+
tqdm
|
9 |
+
# urllib
|
10 |
+
requests
|
11 |
+
boto3
|
12 |
+
datasets
|
13 |
+
transformers
|
14 |
+
gradio
|
train.py
ADDED
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from deepseek_v3 import DeepSeekV3Model
|
2 |
+
import torch
|
3 |
+
import yaml
|
4 |
+
from transformers import AutoTokenizer
|
5 |
+
# from gptdataloader import GPTDataLoader
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
import numpy as np
|
8 |
+
from datasets import load_dataset
|
9 |
+
import logging
|
10 |
+
import math
|
11 |
+
|
12 |
+
from utils import upload_file_to_s3
|
13 |
+
# At the start of training loop
|
14 |
+
# print(f"GPU Memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
|
15 |
+
# print(f"GPU Memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
|
16 |
+
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
20 |
+
file_handler = logging.FileHandler('training.log')
|
21 |
+
file_handler.setFormatter(formatter) # Set formatter on the handler, not the logger
|
22 |
+
logger.addHandler(file_handler)
|
23 |
+
logger.setLevel(logging.INFO)
|
24 |
+
|
25 |
+
def encode_text(examples, tokenizer, seq_length):
|
26 |
+
"""Tokenize and prepare text examples for training."""
|
27 |
+
tokens = tokenizer(
|
28 |
+
examples["text"],
|
29 |
+
truncation=True,
|
30 |
+
padding="max_length",
|
31 |
+
max_length=seq_length + 1,
|
32 |
+
return_tensors="pt",
|
33 |
+
)
|
34 |
+
# Use clone().detach() as recommended
|
35 |
+
input_ids = tokens["input_ids"].squeeze(0).clone().detach()
|
36 |
+
input_ids = torch.clamp(input_ids, min=0, max=tokenizer.vocab_size - 1)
|
37 |
+
labels = input_ids.clone().detach()
|
38 |
+
labels = labels[1:].to(torch.int64)
|
39 |
+
input_ids = input_ids[:-1].to(torch.int64)
|
40 |
+
|
41 |
+
return {"input_ids": input_ids, "labels": labels}
|
42 |
+
|
43 |
+
def load_cosmopedia_dataset(batch_size=8, seq_length=1024, tokenizer=None):
|
44 |
+
"""
|
45 |
+
Returns a torch dataloader for the cosmopedia dataset
|
46 |
+
"""
|
47 |
+
# Set tokenizer parallelism explicitly
|
48 |
+
import os
|
49 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
50 |
+
logger.info("tokenizer parallelism set to false")
|
51 |
+
try:
|
52 |
+
# Increase timeout and retries for dataset loading
|
53 |
+
from datasets import config
|
54 |
+
config.HF_DATASETS_TIMEOUT = 300 # 5 minutes timeout
|
55 |
+
config.MAX_RETRIES = 10 # Increase retry attempts
|
56 |
+
logger.info("dataset loading config set")
|
57 |
+
train_dataset = load_dataset(
|
58 |
+
"HuggingFaceTB/smollm-corpus",
|
59 |
+
name="cosmopedia-v2",
|
60 |
+
split="train",
|
61 |
+
streaming=True,
|
62 |
+
)
|
63 |
+
logger.info("dataset loaded")
|
64 |
+
|
65 |
+
# Use partial to bind tokenizer and seq_length to the encode function
|
66 |
+
from functools import partial
|
67 |
+
encode_fn = partial(encode_text, tokenizer=tokenizer, seq_length=seq_length)
|
68 |
+
|
69 |
+
train_dataset = train_dataset.map(
|
70 |
+
encode_fn,
|
71 |
+
remove_columns=["text"],
|
72 |
+
batched=False
|
73 |
+
)
|
74 |
+
train_dataset = train_dataset.with_format("torch")
|
75 |
+
|
76 |
+
train_dataloader = DataLoader(
|
77 |
+
train_dataset,
|
78 |
+
batch_size=batch_size,
|
79 |
+
num_workers=2,
|
80 |
+
pin_memory=True,
|
81 |
+
prefetch_factor=4,
|
82 |
+
persistent_workers=True
|
83 |
+
)
|
84 |
+
return train_dataloader
|
85 |
+
except Exception as e:
|
86 |
+
logger.error(f"Error loading dataset: {str(e)}")
|
87 |
+
return None
|
88 |
+
|
89 |
+
# def create_dataloader(file_path, tokenizer, context_size, stride):
|
90 |
+
# with open(file_path, "r") as file:
|
91 |
+
# text_data = file.read()
|
92 |
+
# total_characters = len(text_data)
|
93 |
+
# total_tokens = len(tokenizer.encode(text_data))
|
94 |
+
# logger.info(f"Characters: {total_characters}")
|
95 |
+
# logger.info(f"Tokens: {total_tokens}")
|
96 |
+
|
97 |
+
# # create dataloader
|
98 |
+
# train_ratio = 0.9
|
99 |
+
# val_ratio = 0.1
|
100 |
+
# split_idx = int(train_ratio * total_characters)
|
101 |
+
|
102 |
+
# train_data = text_data[:split_idx]
|
103 |
+
|
104 |
+
# valid_data = text_data[split_idx:]
|
105 |
+
|
106 |
+
# train_dataset = GPTDataLoader(train_data, tokenizer, context_size, stride)
|
107 |
+
# valid_dataset = GPTDataLoader(valid_data, tokenizer, context_size, stride)
|
108 |
+
# return DataLoader(train_dataset, batch_size=10, shuffle=True, drop_last=True), DataLoader(valid_dataset, batch_size=10, shuffle=False, drop_last=True)
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
# def calculate_loss_batch(input_batch, target_batch, model, device):
|
114 |
+
# input_batch = input_batch.to(device)
|
115 |
+
# target_batch = target_batch.to(device)
|
116 |
+
|
117 |
+
# logits, loss = model(input_batch, target_batch) # e.g. 10, 32, 49152
|
118 |
+
# logits = logits.view(-1, logits.size(-1)) # Shape: [320, 49152]
|
119 |
+
# target_batch = target_batch.view(-1) # Shape: [320]
|
120 |
+
# loss = torch.nn.functional.cross_entropy(logits, target_batch)
|
121 |
+
# return loss
|
122 |
+
|
123 |
+
# def calc_loss_loader(data_loader, model, device, num_batches=None):
|
124 |
+
# total_loss = 0.0
|
125 |
+
# if len(data_loader) == 0:
|
126 |
+
# return float("nan")
|
127 |
+
# elif num_batches is None:
|
128 |
+
# num_batches = len(data_loader)
|
129 |
+
# else:
|
130 |
+
# num_batches = min(num_batches, len(data_loader))
|
131 |
+
# for i, (input_batch, target_batch) in enumerate(data_loader):
|
132 |
+
# if i < num_batches:
|
133 |
+
# loss = calculate_loss_batch(input_batch, target_batch, model, device)
|
134 |
+
# total_loss += loss.item()
|
135 |
+
# else:
|
136 |
+
# break
|
137 |
+
# return total_loss / num_batches
|
138 |
+
|
139 |
+
# def evaluate_model(model, train_dataloader, valid_dataloader, device, eval_iter=100):
|
140 |
+
# model.eval()
|
141 |
+
# with torch.no_grad():
|
142 |
+
# train_loss = calc_loss_loader(train_dataloader, model, device, num_batches=eval_iter)
|
143 |
+
# valid_loss = calc_loss_loader(valid_dataloader, model, device, num_batches=eval_iter)
|
144 |
+
# model.train()
|
145 |
+
# return train_loss, valid_loss
|
146 |
+
|
147 |
+
def generate(model, idx, max_new_tokens, context_length, temperature=1.0, top_k=None, eos_token=None, device=None):
|
148 |
+
logger.info(f"Generating on device {device}")
|
149 |
+
model = model.to(device)
|
150 |
+
idx = idx.to(device)
|
151 |
+
model.eval()
|
152 |
+
for _ in range(max_new_tokens):
|
153 |
+
idx_cond = idx[:, -context_length:]
|
154 |
+
with torch.no_grad():
|
155 |
+
logits, _ = model(idx_cond) # Unpack both logits and loss (ignore loss)
|
156 |
+
logits = logits.view(idx_cond.shape[0], -1, model.config['vocab_size']) # Reshape to [batch, seq, vocab]
|
157 |
+
|
158 |
+
# Get the logits for the last token only
|
159 |
+
logits = logits[:, -1, :] # Shape: [batch_size, vocab_size]
|
160 |
+
|
161 |
+
if top_k is not None:
|
162 |
+
# top k sampling
|
163 |
+
top_logits, top_pos = torch.topk(logits, top_k)
|
164 |
+
min_logit = top_logits[:, -1].unsqueeze(-1)
|
165 |
+
logits = torch.where(logits < min_logit,
|
166 |
+
torch.tensor(float('-inf')).to(logits.device),
|
167 |
+
logits)
|
168 |
+
|
169 |
+
# temperature scaling
|
170 |
+
if temperature > 0.0:
|
171 |
+
logits /= temperature
|
172 |
+
probs = torch.softmax(logits, dim=-1)
|
173 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
174 |
+
else:
|
175 |
+
idx_next = torch.argmax(logits, dim=-1, keepdim=True)
|
176 |
+
|
177 |
+
if idx_next.item() == eos_token:
|
178 |
+
break
|
179 |
+
|
180 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
181 |
+
model.train()
|
182 |
+
return idx
|
183 |
+
|
184 |
+
def sync_device(device):
|
185 |
+
if device.startswith('cuda'):
|
186 |
+
torch.cuda.synchronize()
|
187 |
+
elif device == 'cpu':
|
188 |
+
torch.cpu.synchronize() if hasattr(torch.cpu, 'synchronize') else None
|
189 |
+
elif device.startswith('mps'): # For Apple Silicon
|
190 |
+
torch.mps.synchronize()
|
191 |
+
|
192 |
+
def print_gpu_memory(step_name=""):
|
193 |
+
"""
|
194 |
+
Print GPU memory statistics with a specified step name
|
195 |
+
"""
|
196 |
+
if torch.cuda.is_available():
|
197 |
+
logger.info(f"\nGPU Memory Stats {step_name}:")
|
198 |
+
logger.info(f"GPU Memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
|
199 |
+
logger.info(f"GPU Memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
|
200 |
+
logger.info(f"Max GPU Memory allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
|
201 |
+
|
202 |
+
# Learning rate scheduler
|
203 |
+
def get_lr_lambda(current_step, warmup_steps, max_steps, max_lr):
|
204 |
+
"""
|
205 |
+
Modified learning rate scheduler with:
|
206 |
+
1. Linear warmup for first 3000 steps
|
207 |
+
2. Cosine decay from 3000 to 60000 steps
|
208 |
+
3. Minimum learning rate of 1.5e-5 (5% of max_lr)
|
209 |
+
"""
|
210 |
+
min_lr = max_lr * 0.05 # Minimum learning rate (5% of max_lr)
|
211 |
+
|
212 |
+
if current_step < warmup_steps:
|
213 |
+
# Linear warmup from 0 to max_lr
|
214 |
+
return float(current_step) / float(max(1, warmup_steps))
|
215 |
+
else:
|
216 |
+
# Cosine decay from max_lr to min_lr
|
217 |
+
progress = float(current_step - warmup_steps) / float(max(1, max_steps - warmup_steps))
|
218 |
+
return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
|
219 |
+
|
220 |
+
|
221 |
+
def train_model(config, model, train_loader, test_loader, optimizer, device, num_epochs, eval_freq, eval_iter, start_context="Jack Gisburn rather a cheap genius- ", tokenizer=None):
|
222 |
+
total_loss = 0
|
223 |
+
tokens_seen, global_step = 0, -1
|
224 |
+
|
225 |
+
# Adjusted gradient accumulation setup for batch size 8
|
226 |
+
actual_batch_size = config['tokens']['micro_batch_size'] # Now 8
|
227 |
+
effective_batch_size_multiplier = 1 # Adjusted for batch size 8
|
228 |
+
target_batch_size = effective_batch_size_multiplier * config['tokens']['micro_batch_size']
|
229 |
+
gradient_accumulation_steps = target_batch_size // actual_batch_size
|
230 |
+
|
231 |
+
# Learning rate parameters adjusted for batch size 8
|
232 |
+
max_lr = 3e-4 # Keep the same max learning rate
|
233 |
+
warmup_steps = 3000 # Keep warmup steps
|
234 |
+
max_steps = 60000 # Keep max steps
|
235 |
+
min_lr = max_lr * 0.05 # Keep minimum LR at 5% of max
|
236 |
+
|
237 |
+
# Create LambdaLR scheduler with the improved lambda function
|
238 |
+
lr_lambda = lambda step: get_lr_lambda(step, warmup_steps, max_steps, max_lr)
|
239 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
240 |
+
|
241 |
+
logger.info(f"Training with learning rate schedule:")
|
242 |
+
logger.info(f"Max LR: {max_lr}")
|
243 |
+
logger.info(f"Warmup Steps: {warmup_steps}")
|
244 |
+
logger.info(f"Max Steps: {max_steps}")
|
245 |
+
logger.info(f"Min LR: {max_lr * 0.05}")
|
246 |
+
logger.info(f"Gradient Accumulation Steps: {gradient_accumulation_steps}")
|
247 |
+
logger.info(f"Effective Batch Size: {actual_batch_size * gradient_accumulation_steps}")
|
248 |
+
|
249 |
+
print_gpu_memory("at start of training")
|
250 |
+
|
251 |
+
# Add these near the start of training loop
|
252 |
+
torch.cuda.empty_cache()
|
253 |
+
torch.backends.cudnn.benchmark = True
|
254 |
+
for epoch in range(num_epochs):
|
255 |
+
model.train()
|
256 |
+
optimizer.zero_grad() # Zero gradients at start of epoch
|
257 |
+
|
258 |
+
for batch_idx, batch in enumerate(train_loader):
|
259 |
+
input_batch = batch['input_ids'].to(device)
|
260 |
+
target_batch = batch['labels'].to(device)
|
261 |
+
|
262 |
+
# Forward pass
|
263 |
+
with torch.autocast(device_type=device, dtype=torch.bfloat16):
|
264 |
+
logits, original_loss = model(input_batch, target_batch)
|
265 |
+
|
266 |
+
# Scale loss for gradient accumulation
|
267 |
+
scaled_loss = original_loss / gradient_accumulation_steps
|
268 |
+
scaled_loss.backward()
|
269 |
+
|
270 |
+
# Add the original loss to total_loss for logging
|
271 |
+
total_loss += original_loss.item() # Don't multiply back up
|
272 |
+
tokens_seen += input_batch.numel()
|
273 |
+
|
274 |
+
# Calculate running average loss
|
275 |
+
total_batches = batch_idx + 1
|
276 |
+
avg_loss = total_loss / total_batches
|
277 |
+
if batch_idx % 25 == 0:
|
278 |
+
logger.info(f"Batch {batch_idx + 1}, Running Avg Loss: {avg_loss:.5f}")
|
279 |
+
# Only update weights after accumulating gradients
|
280 |
+
if (batch_idx + 1) % gradient_accumulation_steps == 0:
|
281 |
+
# Gradient clipping
|
282 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
283 |
+
|
284 |
+
optimizer.step()
|
285 |
+
scheduler.step() # Update learning rate
|
286 |
+
optimizer.zero_grad()
|
287 |
+
global_step += 1
|
288 |
+
|
289 |
+
# Evaluation block
|
290 |
+
if global_step % eval_freq == 0 and global_step > 0:
|
291 |
+
# Use total batches processed instead of global_step
|
292 |
+
current_lr = scheduler.get_last_lr()[0]
|
293 |
+
optimizer_lr = optimizer.param_groups[0]['lr']
|
294 |
+
|
295 |
+
print_gpu_memory(f"at step {global_step}")
|
296 |
+
logger.info(f"learning rate: {current_lr:.8f}")
|
297 |
+
logger.info(f"Ep {epoch+1} (Step {global_step:06d}): "
|
298 |
+
f"Avg loss {avg_loss:.3f} | {tokens_seen} tokens seen")
|
299 |
+
logger.info(f"optimizer lr: {optimizer_lr:.8f}")
|
300 |
+
logger.info(f"scheduler lr: {current_lr:.8f}")
|
301 |
+
|
302 |
+
# Generate sample text
|
303 |
+
start_context_list = ["In today's ever-evolving world, technology has become an integral part of our lives","Once upon a time, there was a friendly agency called Gaudette Insurance Agency, Inc. They help","A couple of years ago, I was working as an extra on the set of a low-budget British film.","Introduction: The Art of Crafting Vegan Sandwich Delights Sandwiches occupy a unique space in","Meet Chris, a superhero of supplies! Just like how Batman protects Gotham City","Identity formation is a complex and multifaceted process that involves the development of", "With the development of science and technology, computer has become more and more ","Just as there are many variants and forms of electronic malware and Internet-based ","Correctly identifying what is causing a problem is the most important step in pest control.","Lobster, California spiny The California Spiny Lobster fishery is a small but locally ","Bees are vital for pollination. You can buy leafcutter bee houses to attract ","Feeling Alone Together: Exploring Alienation and Isolation in Literature", "Imagine if someone got their hands on dangerous weapons","Once upon a time, in a colorful town called Popville, ","he bell above the door jangled as Sarah walked into her family's hardware store"]
|
304 |
+
# Randomly select a prompt from the list
|
305 |
+
random_prompt = np.random.choice(start_context_list)
|
306 |
+
logger.info(f"Selected prompt: {random_prompt}")
|
307 |
+
logger.info(f"+++"*30)
|
308 |
+
encoded_text = tokenizer.encode(random_prompt, return_tensors="pt")
|
309 |
+
random_topk = np.random.randint(1, 10)
|
310 |
+
logger.info(f"random_topk: {random_topk}")
|
311 |
+
random_temperature = np.random.uniform(0.7, 0.9)
|
312 |
+
logger.info(f"random_temperature: {random_temperature}")
|
313 |
+
logger.info(f"global step {global_step} , batch_idx {batch_idx} => generating text")
|
314 |
+
generated_text = generate(model,
|
315 |
+
idx=encoded_text,
|
316 |
+
max_new_tokens=256,
|
317 |
+
context_length=256,
|
318 |
+
temperature=random_temperature,
|
319 |
+
top_k=random_topk,
|
320 |
+
eos_token=tokenizer.eos_token_id,
|
321 |
+
device=device)
|
322 |
+
logger.info(f"+++"*30)
|
323 |
+
logger.info(tokenizer.decode(generated_text.squeeze(0)))
|
324 |
+
logger.info(f"+++"*30)
|
325 |
+
|
326 |
+
# Save checkpoint
|
327 |
+
model_file_name = f"model_{global_step}_steps_avg_loss_{avg_loss:.5f}_optimizer_lr_{optimizer_lr:.8f}.pth"
|
328 |
+
torch.save({
|
329 |
+
'step': global_step,
|
330 |
+
'model_state_dict': model.state_dict(),
|
331 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
332 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
333 |
+
'loss': avg_loss,
|
334 |
+
}, model_file_name)
|
335 |
+
|
336 |
+
s3_path = upload_file_to_s3(model_file_name, config['model']['model_config']['s3_bucket'],
|
337 |
+
config['model']['model_config']['s3_checkpoint_folder'])
|
338 |
+
logger.info(f"Model saved to S3: {s3_path}")
|
339 |
+
|
340 |
+
log_path = upload_file_to_s3(config['model']['model_config']['s3_log_file_name'], config['model']['model_config']['s3_bucket'],
|
341 |
+
config['model']['model_config']['s3_log_folder'])
|
342 |
+
logger.info(f"Log saved to S3: {log_path}")
|
343 |
+
|
344 |
+
if batch_idx % 100 == 0:
|
345 |
+
logger.info(f"Batch {batch_idx} finished")
|
346 |
+
logger.info(f"+++"*30)
|
347 |
+
|
348 |
+
logger.info("Training complete")
|
349 |
+
|
350 |
+
if __name__ == "__main__":
|
351 |
+
config = yaml.load(open("config_smollm2_135M.yaml", "r"), Loader=yaml.FullLoader)
|
352 |
+
logger.info(config)
|
353 |
+
|
354 |
+
# Set memory efficient settings
|
355 |
+
torch.set_float32_matmul_precision('high')
|
356 |
+
torch.backends.cudnn.benchmark = True
|
357 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
358 |
+
|
359 |
+
# Empty cache before model creation
|
360 |
+
torch.cuda.empty_cache()
|
361 |
+
import os
|
362 |
+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64'
|
363 |
+
|
364 |
+
model = DeepSeekV3Model(config['model'])
|
365 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
366 |
+
|
367 |
+
# Enable gradient checkpointing for memory efficiency
|
368 |
+
# model.gradient_checkpointing_enable()
|
369 |
+
|
370 |
+
model.to(device)
|
371 |
+
#model = torch.compile(model)
|
372 |
+
logger.info(model)
|
373 |
+
logger.info("++"*30)
|
374 |
+
total_params = sum(p.numel() for p in model.parameters())
|
375 |
+
logger.info(f"Total parameters: {total_params}")
|
376 |
+
|
377 |
+
optimizer = torch.optim.AdamW(
|
378 |
+
model.parameters(),
|
379 |
+
lr=3e-4,
|
380 |
+
weight_decay=0.15,
|
381 |
+
betas=(0.9, 0.95)
|
382 |
+
)
|
383 |
+
|
384 |
+
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
|
385 |
+
tokenizer.pad_token = tokenizer.eos_token
|
386 |
+
vocab_size = tokenizer.vocab_size
|
387 |
+
|
388 |
+
# Adjusted batch size to 8
|
389 |
+
train_loader = load_cosmopedia_dataset(
|
390 |
+
batch_size=8, # Changed from 4 to 8
|
391 |
+
seq_length=512, # Kept at 512
|
392 |
+
tokenizer=tokenizer
|
393 |
+
)
|
394 |
+
|
395 |
+
import time
|
396 |
+
t1 = time.time()
|
397 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
398 |
+
|
399 |
+
# Set environment variable for memory allocation
|
400 |
+
import os
|
401 |
+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
|
402 |
+
|
403 |
+
train_model(
|
404 |
+
config,
|
405 |
+
model,
|
406 |
+
train_loader,
|
407 |
+
train_loader,
|
408 |
+
optimizer=optimizer,
|
409 |
+
device=device,
|
410 |
+
num_epochs=1,
|
411 |
+
eval_freq=2500, # Increase eval frequency to every 500 steps
|
412 |
+
eval_iter=2500,
|
413 |
+
start_context="Once Upon a Time far far away in a galaxy",
|
414 |
+
tokenizer=tokenizer
|
415 |
+
)
|
416 |
+
t2 = time.time()
|
417 |
+
logger.info(f"Time taken for training: {t2 - t1:.2f} seconds")
|
utils.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import boto3
|
2 |
+
from boto3.s3.transfer import TransferConfig
|
3 |
+
from tqdm import tqdm
|
4 |
+
import os
|
5 |
+
|
6 |
+
def upload_file_to_s3(file_path, bucket_name, s3_prefix):
|
7 |
+
|
8 |
+
|
9 |
+
class ProgressPercentage(object):
|
10 |
+
def __init__(self, filename):
|
11 |
+
self._filename = filename
|
12 |
+
self._size = float(os.path.getsize(filename))
|
13 |
+
self._seen_so_far = 0
|
14 |
+
self._pbar = tqdm(total=self._size, unit='B', unit_scale=True, desc=f"Uploading {os.path.basename(filename)}")
|
15 |
+
|
16 |
+
def __call__(self, bytes_amount):
|
17 |
+
self._seen_so_far += bytes_amount
|
18 |
+
self._pbar.update(bytes_amount)
|
19 |
+
|
20 |
+
s3_client = boto3.client('s3')
|
21 |
+
file_name = os.path.basename(file_path)
|
22 |
+
s3_path = f"{s3_prefix}/{file_name}"
|
23 |
+
|
24 |
+
# Configure multipart upload
|
25 |
+
config = TransferConfig(
|
26 |
+
multipart_threshold=1024 * 25, # 25MB
|
27 |
+
max_concurrency=10,
|
28 |
+
multipart_chunksize=1024 * 25, # 25MB
|
29 |
+
use_threads=True
|
30 |
+
)
|
31 |
+
|
32 |
+
try:
|
33 |
+
s3_client.upload_file(
|
34 |
+
file_path,
|
35 |
+
bucket_name,
|
36 |
+
s3_path,
|
37 |
+
Config=config,
|
38 |
+
Callback=ProgressPercentage(file_path)
|
39 |
+
)
|
40 |
+
return f"s3://{bucket_name}/{s3_path}"
|
41 |
+
except Exception as e:
|
42 |
+
print(f"Failed to upload {file_path} to S3: {str(e)}")
|
43 |
+
return None
|
44 |
+
|
45 |
+
max_lr = 1e-3
|
46 |
+
warmup_steps = 10
|
47 |
+
max_steps = 25000
|
48 |
+
import math
|
49 |
+
def get_lr_lambda(current_step, warmup_steps, max_steps, max_lr):
|
50 |
+
"""
|
51 |
+
Learning rate scheduler with:
|
52 |
+
1. Linear warmup
|
53 |
+
2. Cosine decay
|
54 |
+
3. Minimum learning rate of 10% of max_lr
|
55 |
+
"""
|
56 |
+
min_lr = max_lr * 0.1 # Minimum learning rate (10% of max_lr)
|
57 |
+
|
58 |
+
if current_step < warmup_steps:
|
59 |
+
# Linear warmup
|
60 |
+
return max_lr * (current_step + 1) / warmup_steps
|
61 |
+
elif current_step > max_steps:
|
62 |
+
# After max_steps, return minimum learning rate
|
63 |
+
return min_lr
|
64 |
+
else:
|
65 |
+
# Cosine decay between warmup_steps and max_steps
|
66 |
+
decay_ratio = (current_step - warmup_steps) / (max_steps - warmup_steps)
|
67 |
+
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
68 |
+
return min_lr + coeff * (max_lr - min_lr)
|
69 |
+
|
70 |
+
|
71 |
+
def plot_lr_schedule():
|
72 |
+
"""
|
73 |
+
Helper function to visualize the learning rate schedule
|
74 |
+
"""
|
75 |
+
import matplotlib.pyplot as plt
|
76 |
+
steps = list(range(0, max_steps + 100))
|
77 |
+
lrs = [get_lr_lambda(step, warmup_steps, max_steps, max_lr) for step in steps]
|
78 |
+
|
79 |
+
plt.figure(figsize=(10, 5))
|
80 |
+
plt.plot(steps, lrs)
|
81 |
+
plt.title('Learning Rate Schedule')
|
82 |
+
plt.xlabel('Steps')
|
83 |
+
plt.ylabel('Learning Rate')
|
84 |
+
plt.grid(True)
|
85 |
+
plt.show()
|
86 |
+
|
87 |
+
def plot_training_loss(log_file_path, output_path=None):
|
88 |
+
"""
|
89 |
+
Parse a training log file and plot the running average loss against batch steps.
|
90 |
+
Also adds a trend line to visualize the overall training progress.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
log_file_path (str): Path to the training log file
|
94 |
+
output_path (str, optional): Path to save the plot as PNG. If None, displays the plot instead.
|
95 |
+
"""
|
96 |
+
import re
|
97 |
+
import matplotlib.pyplot as plt
|
98 |
+
import numpy as np
|
99 |
+
from scipy.optimize import curve_fit
|
100 |
+
|
101 |
+
# Regular expression to extract batch number and loss
|
102 |
+
pattern = r"Batch (\d+), Running Avg Loss: ([0-9.]+)"
|
103 |
+
|
104 |
+
steps = []
|
105 |
+
losses = []
|
106 |
+
|
107 |
+
# Read and parse the log file
|
108 |
+
with open(log_file_path, 'r') as file:
|
109 |
+
for line in file:
|
110 |
+
match = re.search(pattern, line)
|
111 |
+
if match:
|
112 |
+
batch_num = int(match.group(1))
|
113 |
+
loss = float(match.group(2))
|
114 |
+
steps.append(batch_num)
|
115 |
+
losses.append(loss)
|
116 |
+
|
117 |
+
if not steps:
|
118 |
+
print("No loss data found in the log file.")
|
119 |
+
return
|
120 |
+
|
121 |
+
# Create the plot
|
122 |
+
plt.figure(figsize=(12, 6))
|
123 |
+
plt.plot(steps, losses, 'b-', alpha=0.5, label='Running Avg Loss')
|
124 |
+
|
125 |
+
# Add trend line (using polynomial fit)
|
126 |
+
def poly_func(x, a, b, c):
|
127 |
+
return a * x**2 + b * x + c
|
128 |
+
|
129 |
+
# Convert to numpy arrays for curve fitting
|
130 |
+
x_array = np.array(steps)
|
131 |
+
y_array = np.array(losses)
|
132 |
+
|
133 |
+
# Fit the curve
|
134 |
+
try:
|
135 |
+
popt, _ = curve_fit(poly_func, x_array, y_array)
|
136 |
+
x_line = np.linspace(min(steps), max(steps), 1000)
|
137 |
+
y_line = poly_func(x_line, *popt)
|
138 |
+
plt.plot(x_line, y_line, 'r-', label='Trend Line')
|
139 |
+
except Exception as e:
|
140 |
+
print(f"Could not fit trend line: {e}")
|
141 |
+
# Fallback to simple moving average for trend
|
142 |
+
window_size = min(len(steps) // 10, 100) if len(steps) > 100 else len(steps) // 2
|
143 |
+
if window_size > 0:
|
144 |
+
moving_avg = np.convolve(y_array, np.ones(window_size)/window_size, mode='valid')
|
145 |
+
plt.plot(steps[window_size-1:], moving_avg, 'r-', label='Moving Average Trend')
|
146 |
+
|
147 |
+
# Add labels and title
|
148 |
+
plt.xlabel('Batch Number')
|
149 |
+
plt.ylabel('Running Average Loss')
|
150 |
+
plt.title('Training Loss Over Time')
|
151 |
+
plt.grid(True)
|
152 |
+
plt.legend()
|
153 |
+
|
154 |
+
# Add min and max loss annotations
|
155 |
+
min_loss = min(losses)
|
156 |
+
min_idx = losses.index(min_loss)
|
157 |
+
max_loss = max(losses)
|
158 |
+
max_idx = losses.index(max_loss)
|
159 |
+
|
160 |
+
plt.annotate(f'Min: {min_loss:.5f}',
|
161 |
+
xy=(steps[min_idx], min_loss),
|
162 |
+
xytext=(steps[min_idx], min_loss*1.05),
|
163 |
+
arrowprops=dict(facecolor='green', shrink=0.05),
|
164 |
+
fontsize=10)
|
165 |
+
|
166 |
+
plt.annotate(f'Max: {max_loss:.5f}',
|
167 |
+
xy=(steps[max_idx], max_loss),
|
168 |
+
xytext=(steps[max_idx], max_loss*0.95),
|
169 |
+
arrowprops=dict(facecolor='red', shrink=0.05),
|
170 |
+
fontsize=10)
|
171 |
+
|
172 |
+
# Save or show the plot
|
173 |
+
plt.tight_layout()
|
174 |
+
if output_path:
|
175 |
+
plt.savefig(output_path, dpi=300, bbox_inches='tight')
|
176 |
+
print(f"Plot saved to {output_path}")
|
177 |
+
else:
|
178 |
+
plt.show()
|
179 |
+
|
180 |
+
if __name__ == "__main__":
|
181 |
+
# plot_lr_schedule()
|
182 |
+
plot_training_loss("training.log", "train_loss.png")
|