Gemma_models
Collection: Gemma models (1B, 2B, SFT, etc.) trained from scratch
Base model trained on 5B tokens of mostly cleaned Kazakh (kk) data.
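Quick inference example (assumes a transformers release that ships Gemma3ForCausalLM):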
import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# Load the trained model
model_path = "SRP-base-model-training/gemma_3_2B_base_v1_kk_only_5B-data"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Gemma3ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
prompt = "Сәлем"  # "Hello" in Kazakh
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
    generation = model.generate(
        **model_inputs,
        max_new_tokens=50,
        do_sample=True,
        # temperature=0.2,
        top_p=0.9,  # nucleus sampling
        # top_k=50,  # keep only the 50 most likely tokens
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id
    )
    generation = generation[0][input_len:]

decoded = tokenizer.decode(generation, skip_special_tokens=True)
print(decoded)
Main script for training
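The script below builds a Gemma 3 text config from scratch (or loads an existing checkpoint when --model_path is given), concatenates every *.jsonl file found under --data_dir, wraps each document as BOS + text + EOS, and trains it as a causal LM with TRL's SFTTrainer using sequence packing and DeepSpeed stage 1.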
# train_gemma_pretraining.py
import os, math, json, argparse, torch, wandb
from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoTokenizer, Gemma3TextConfig, Gemma3ForCausalLM,
    DataCollatorForLanguageModeling
)
from trl import SFTTrainer, SFTConfig
from pathlib import Path
# ─── arguments ───
def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--max_seq_length", type=int, default=2048)
    p.add_argument("--per_device_batch_size", type=int, default=2)
    p.add_argument("--gradient_accumulation_steps", type=int, default=1)
    p.add_argument("--learning_rate", type=float, default=5e-5)
    p.add_argument("--output_dir", type=str, default="runs/gemma_multilingual_pretraining")
    p.add_argument("--wandb_project", type=str, default="gemma-multilingual-pretraining")
    p.add_argument("--wandb_run_name", type=str, default=None)
    p.add_argument("--tokenizer_path", type=str, required=True)
    p.add_argument("--data_dir", type=str, default="/scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/kk_data/processed_kk_data_all")
    p.add_argument("--model_path", type=str, default=None)
    return p.parse_args()
args = parse_args()
MAX_LEN = args.max_seq_length
os.environ["WANDB_PROJECT"] = args.wandb_project
# Enable tokenizer parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# ─── tokenizer ───
tok = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=True)
# For base pretraining, we don't need conversation tokens yet
# They will be added later during SFT
print(f"Tokenizer vocab size: {len(tok)}")
print(f"BOS token: {tok.bos_token} (id: {tok.bos_token_id})")
print(f"EOS token: {tok.eos_token} (id: {tok.eos_token_id})")
print(f"PAD token: {tok.pad_token} (id: {tok.pad_token_id})")
if args.model_path is None:
    print("Creating model from scratch...")
    # 1B model with your custom vocab
    cfg = Gemma3TextConfig(
        vocab_size=len(tok),  # Use actual tokenizer vocab size
        bos_token_id=tok.bos_token_id,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
        hidden_size=2304,
        num_hidden_layers=26,
        num_attention_heads=4,
        head_dim=256,
        intermediate_size=9216,
        max_position_embeddings=32768,
        _attn_implementation="eager",
        torch_dtype="bfloat16",
    )
    model = Gemma3ForCausalLM(cfg)
    model.resize_token_embeddings(len(tok))
else:
    print(f"Loading model from {args.model_path}...")
    model = Gemma3ForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        _attn_implementation="eager"
    )
# ─── dataset ───
# Find all processed JSONL files
data_dir = Path(args.data_dir)
jsonl_files = list(data_dir.glob("*.jsonl"))
if not jsonl_files:
    raise ValueError(f"No .jsonl files found in {args.data_dir}")
print(f"Found {len(jsonl_files)} JSONL files:")
for f in jsonl_files:
    print(f" - {f.name}")
# Load all JSONL files
print("Loading datasets...")
datasets = []
for jsonl_file in jsonl_files:
print(f"Loading {jsonl_file.name}...")
ds = load_dataset("json", data_files=str(jsonl_file), split="train")
datasets.append(ds)
print(f" Loaded {len(ds):,} examples")
# Combine all datasets
print("Combining datasets...")
combined_ds = concatenate_datasets(datasets)
print(f"Total examples: {len(combined_ds):,}")
# Use all CPU cores
cpu_workers = os.cpu_count()
print(f"Using {cpu_workers} CPU workers")
def format_for_pretraining(examples):
"""
Format texts for base pretraining - simple BOS + text + EOS
This is the standard format for language model pretraining
"""
formatted_texts = []
for text in examples['text']:
# Simple pretraining format: just BOS + text + EOS
formatted_text = f"{tok.bos_token}{text}{tok.eos_token}"
formatted_texts.append(formatted_text)
return {"text": formatted_texts}
# Format the data for pretraining
print("Formatting data for pretraining...")
ds = combined_ds.map(
    format_for_pretraining,
    batched=True,
    batch_size=1000,
    num_proc=cpu_workers,
    desc="Formatting texts"
)
# # Optional: Double-check token lengths after formatting
# def check_token_lengths(examples):
#     """Check if any examples exceed max length after formatting"""
#     tokenized = tok(
#         examples["text"],
#         truncation=False,
#         padding=False,
#         return_length=True
#     )
#     return [length <= MAX_LEN for length in tokenized["length"]]
# print("Checking token lengths after formatting...")
# original_size = len(ds)
# ds = ds.filter(
#     check_token_lengths,
#     batched=True,
#     batch_size=1000,
#     num_proc=cpu_workers,
#     desc="Filtering by token length"
# )
# final_size = len(ds)
# print(f"After final filtering: {final_size:,} examples (removed {original_size - final_size:,})")
# Shuffle the combined dataset
print("Shuffling dataset...")
ds = ds.shuffle(seed=42)
# To train on only a subset of the data:
# subset_size = len(ds) // 10 # 10%
# ds = ds.select(range(subset_size))
num_samples = len(ds)
world = int(os.environ.get("WORLD_SIZE", 1))
eff_batch = args.per_device_batch_size * args.gradient_accumulation_steps * world
max_steps = math.ceil(num_samples / eff_batch)
print(f"Dataset rows: {num_samples:,} | effective batch: {eff_batch} | max_steps: {max_steps}")
# ─── collator ───
collator = DataCollatorForLanguageModeling(
    tokenizer=tok,
    mlm=False,
)
# ─── trainer ───
train_args = SFTConfig(
    dataloader_num_workers=8,
    output_dir=args.output_dir,
    max_seq_length=args.max_seq_length,
    packing=True,
    gradient_checkpointing=True,
    fp16=False,
    bf16=True,
    per_device_train_batch_size=args.per_device_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    max_grad_norm=2.0,
    max_steps=-1,
    logging_steps=1,
    warmup_ratio=0.05,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=15,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    num_train_epochs=1.0,
    deepspeed="ds_stage1.json",
    report_to="wandb",
    run_name=args.wandb_run_name,
    group_by_length=False,
    overwrite_output_dir=True,
    dataset_text_field="text",
    # Tell TRL to use all CPUs for its tokenization
    dataset_num_proc=cpu_workers,
)
trainer = SFTTrainer(
    model=model,
    args=train_args,
    train_dataset=ds,
    data_collator=collator,
    processing_class=tok,
    formatting_func=None
)
if __name__ == "__main__":
print(f"Starting training with {len(ds):,} examples...")
trainer.train()
trainer.save_model(f"{args.output_dir}/final_checkpoint")
tok.save_pretrained(f"{args.output_dir}/final_checkpoint")
To run training, use a bash launch script similar to the following:
#!/bin/bash
export TRITON_CACHE_DIR=/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/utils/cache/.triton
mkdir -p "$TRITON_CACHE_DIR"
export WANDB_API_KEY=""
OUTPUT_DIR='/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_1'
WANDB_RUN_NAME='base-model-v1_gemma_1B_test_with_kk_only'
if [ ! -d "$OUTPUT_DIR" ]; then
    mkdir -p "$OUTPUT_DIR"
fi
# --model_path "/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/runs/my_experiment/checkpoint-final" \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --standalone --nproc_per_node 8 base_train_v1.py \
    --tokenizer_path "/scratch/vladimir_albrekht/projects/smollm/models/tokenizers/tok_best_version_50_000_vocab_abai_20_june" \
    --max_seq_length 2048 \
    --per_device_batch_size 32 \
    --gradient_accumulation_steps 8 \
    --learning_rate 3e-4 \
    --output_dir "${OUTPUT_DIR}" \
    --wandb_project "small_llm_SRP" \
    --wandb_run_name "${WANDB_RUN_NAME}"
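With these flags the effective batch size is 32 (per device) × 8 (gradient accumulation) × 8 (GPUs) = 2048 packed sequences per optimizer step, i.e. roughly 2048 × 2048 ≈ 4.2M tokens per step at max_seq_length 2048.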
Note: the trained checkpoint is available at /scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_1/checkpoint-1172.
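To continue pretraining from that checkpoint, pass it via --model_path (a sketch reusing the launch command and variables above; the script reloads the weights only, so optimizer and scheduler state start fresh):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --standalone --nproc_per_node 8 base_train_v1.py \
    --tokenizer_path "/scratch/vladimir_albrekht/projects/smollm/models/tokenizers/tok_best_version_50_000_vocab_abai_20_june" \
    --model_path "/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_1/checkpoint-1172" \
    --max_seq_length 2048 \
    --per_device_batch_size 32 \
    --gradient_accumulation_steps 8 \
    --learning_rate 3e-4 \
    --output_dir "${OUTPUT_DIR}" \
    --wandb_project "small_llm_SRP" \
    --wandb_run_name "${WANDB_RUN_NAME}"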