Base model v1: gemma_3_2B_base_v1_kk_only_5B-data

June 22

Base model trained from scratch on ~5B tokens of mostly cleaned Kazakh (kk) data.

Inference example

import os

# Select GPUs before torch is imported so CUDA is initialised with this mask.
# The variable name is CUDA_VISIBLE_DEVICES (plural); the singular form
# previously used here is not read by CUDA and had no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

# Load the trained model
model_path = "SRP-base-model-training/gemma_3_2B_base_v1_kk_only_5B-data"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Gemma3ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
prompt = "Сәлем"

model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Number of prompt tokens, used below to strip the prompt from the output.
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(
        **model_inputs,
        max_new_tokens=50,
        do_sample=True,
        # temperature=0.2,
        top_p=0.9,  # nucleus sampling
        # top_k=50,  # keep only the 50 most likely tokens
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Keep only the newly generated tokens of the first (only) sequence.
    generation = generation[0][input_len:]

decoded = tokenizer.decode(generation, skip_special_tokens=True)
print(decoded)

Training

Main training script:

# train_gemma_pretraining.py
import os, math, json, argparse, torch, wandb
from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoTokenizer, Gemma3TextConfig, Gemma3ForCausalLM,
    DataCollatorForLanguageModeling  
)
from trl import SFTTrainer, SFTConfig
from pathlib import Path

# ─── arguments ───
def parse_args():
    """Read the pretraining options from the command line.

    ``--tokenizer_path`` is the only mandatory flag; every other option falls
    back to the default listed in the table below.

    Returns:
        argparse.Namespace with one attribute per flag (dashes stripped).
    """
    parser = argparse.ArgumentParser()
    # (flag, add_argument kwargs) in the order they appear in --help.
    option_table = [
        ("--max_seq_length", {"type": int, "default": 2048}),
        ("--per_device_batch_size", {"type": int, "default": 2}),
        ("--gradient_accumulation_steps", {"type": int, "default": 1}),
        ("--learning_rate", {"type": float, "default": 5e-5}),
        ("--output_dir", {"type": str, "default": "runs/gemma_multilingual_pretraining"}),
        ("--wandb_project", {"type": str, "default": "gemma-multilingual-pretraining"}),
        ("--wandb_run_name", {"type": str, "default": None}),
        ("--tokenizer_path", {"type": str, "required": True}),
        ("--data_dir", {"type": str, "default": "/scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/kk_data/processed_kk_data_all"}),
        ("--model_path", {"type": str, "default": None}),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    parsed = parser.parse_args()
    return parsed

# Parse the command line once and reuse the namespace (previously sys.argv
# was parsed twice just to read max_seq_length).
args = parse_args()
MAX_LEN = args.max_seq_length

os.environ["WANDB_PROJECT"] = args.wandb_project
# Enable tokenizer parallelism.
# NOTE(review): the script later forks worker processes (Dataset.map with
# num_proc, dataloader workers); confirm explicit "true" does not cause
# tokenizer deadlocks after fork in the tokenizers version in use.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# ─── tokenizer ───
# Fast tokenizer from --tokenizer_path; its length and special-token ids are
# used below to size the embedding table and fill the model config.
tok = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=True)

# For base pretraining, we don't need conversation tokens yet
# They will be added later during SFT
# NOTE(review): if the tokenizer defines no pad token, pad_token_id printed
# below (and passed into the config) is None — confirm that is intended.
print(f"Tokenizer vocab size: {len(tok)}")
print(f"BOS token: {tok.bos_token} (id: {tok.bos_token_id})")
print(f"EOS token: {tok.eos_token} (id: {tok.eos_token_id})")
print(f"PAD token: {tok.pad_token} (id: {tok.pad_token_id})")

if args.model_path is None:
    print("Creating model from scratch...")
    # Build a Gemma-3 text model directly from a config (no pretrained
    # weights loaded), sized to the custom tokenizer's vocabulary.
    # NOTE(review): this was previously labelled a "1B model", but these dims
    # (hidden 2304, 26 layers, intermediate 9216) form a larger config — the
    # published model card reports 2.13B params.
    cfg = Gemma3TextConfig(
        vocab_size=len(tok),  # Use actual tokenizer vocab size
        bos_token_id=tok.bos_token_id,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
        hidden_size=2304,
        num_hidden_layers=26,
        num_attention_heads=4,
        head_dim=256,
        intermediate_size=9216,
        max_position_embeddings=32768,
        _attn_implementation="eager",
        torch_dtype="bfloat16",
    )
    model = Gemma3ForCausalLM(cfg)
    # vocab_size above is already len(tok), so this resize is presumably a
    # no-op kept as a safeguard.
    model.resize_token_embeddings(len(tok))
else:
    # Initialise from existing weights at --model_path instead of from scratch.
    print(f"Loading model from {args.model_path}...")
    model = Gemma3ForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        _attn_implementation="eager"
    )

# ─── dataset ───
# Find all processed JSONL files. They are sorted because Path.glob yields
# entries in arbitrary, filesystem-dependent order, and the concatenation
# order feeds the seeded shuffle further down — without sorting the same
# seed could produce different data orders on different machines.
data_dir = Path(args.data_dir)
jsonl_files = sorted(data_dir.glob("*.jsonl"))

if not jsonl_files:
    raise ValueError(f"No .jsonl files found in {args.data_dir}")

print(f"Found {len(jsonl_files)} JSONL files:")
for f in jsonl_files:
    print(f"  - {f.name}")

# Load each JSONL file separately so per-file row counts can be logged.
print("Loading datasets...")
datasets = []
for jsonl_file in jsonl_files:
    print(f"Loading {jsonl_file.name}...")
    ds = load_dataset("json", data_files=str(jsonl_file), split="train")
    datasets.append(ds)
    print(f"  Loaded {len(ds):,} examples")

# Combine all datasets
print("Combining datasets...")
combined_ds = concatenate_datasets(datasets)
print(f"Total examples: {len(combined_ds):,}")

# Use all CPU cores. os.cpu_count() may return None when the count is
# undeterminable; fall back to a single worker in that case.
cpu_workers = os.cpu_count() or 1
print(f"Using {cpu_workers} CPU workers")

def format_for_pretraining(examples, bos_token=None, eos_token=None):
    """
    Format texts for base pretraining: BOS + text + EOS.

    This is the plain document format used for language-model pretraining.

    Args:
        examples: a batch from ``Dataset.map(batched=True)``; only the
            ``"text"`` column is read.
        bos_token: string prepended to every text. ``None`` (the default,
            used when called by ``Dataset.map``) falls back to the module
            tokenizer's ``bos_token``.
        eos_token: string appended to every text; ``None`` falls back to the
            tokenizer's ``eos_token``.

    If the tokenizer has no such token, an empty string is used instead of
    the literal text ``"None"`` that a bare f-string would have produced.

    Returns:
        ``{"text": [...]}`` with one formatted string per input text, in order.
    """
    if bos_token is None:
        bos_token = tok.bos_token or ""
    if eos_token is None:
        eos_token = tok.eos_token or ""
    # NOTE(review): if the trainer's own tokenization also adds special
    # tokens, samples may start with two BOS tokens — confirm.
    return {"text": [f"{bos_token}{text}{eos_token}" for text in examples["text"]]}

# Format the data for pretraining: wrap every text with BOS/EOS.
# batched=True passes up to 1000 rows per call to format_for_pretraining;
# num_proc spreads those batches over cpu_workers processes.
print("Formatting data for pretraining...")
ds = combined_ds.map(
    format_for_pretraining,
    batched=True,
    batch_size=1000,
    num_proc=cpu_workers,
    desc="Formatting texts"
)

# # Optional: Double-check token lengths after formatting
# def check_token_lengths(examples):
#     """Check if any examples exceed max length after formatting"""
#     tokenized = tok(
#         examples["text"],
#         truncation=False,
#         padding=False,
#         return_length=True
#     )
#     return [length <= MAX_LEN for length in tokenized["length"]]

# print("Checking token lengths after formatting...")
# original_size = len(ds)
# ds = ds.filter(
#     check_token_lengths,
#     batched=True,
#     batch_size=1000,
#     num_proc=cpu_workers,
#     desc="Filtering by token length"
# )
# final_size = len(ds)

# print(f"After final filtering: {final_size:,} examples (removed {original_size - final_size:,})")

# Shuffle the combined dataset with a fixed seed.
print("Shuffling dataset...")
ds = ds.shuffle(seed=42)
# To train on only a fraction of the data, uncomment:
# subset_size = len(ds) // 10  # 10%
# ds = ds.select(range(subset_size))

# Step estimate for logging only. WORLD_SIZE is set per process by torchrun.
# NOTE(review): max_steps is never passed to the trainer (SFTConfig uses
# max_steps=-1 with num_train_epochs=1.0), and with packing=True the trainer
# iterates over packed sequences rather than these raw documents, so the real
# step count will differ from the printed value — confirm before relying on it.
num_samples = len(ds)
world = int(os.environ.get("WORLD_SIZE", 1))
eff_batch = args.per_device_batch_size * args.gradient_accumulation_steps * world
max_steps = math.ceil(num_samples / eff_batch)

print(f"Dataset rows: {num_samples:,} | effective batch: {eff_batch} | max_steps: {max_steps}")

# ─── collator ───
# mlm=False selects causal (next-token) LM collation rather than masked LM.
# NOTE(review): if tok.pad_token_id equals eos_token_id, check whether EOS
# positions end up masked out of the labels — confirm.
collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tok)

# ─── trainer ───
train_args = SFTConfig(
    # output, checkpointing and logging
    output_dir=args.output_dir,
    overwrite_output_dir=True,
    report_to="wandb",
    run_name=args.wandb_run_name,
    logging_steps=1,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=15,
    # data pipeline; TRL tokenizes the "text" column using all CPUs
    dataset_text_field="text",
    dataset_num_proc=cpu_workers,
    dataloader_num_workers=8,
    max_seq_length=args.max_seq_length,
    packing=True,
    group_by_length=False,
    # batch size and schedule (one epoch; max_steps=-1 leaves it to epochs)
    per_device_train_batch_size=args.per_device_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    num_train_epochs=1.0,
    max_steps=-1,
    learning_rate=args.learning_rate,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    optim="paged_adamw_8bit",
    max_grad_norm=2.0,
    # precision, memory and distribution
    bf16=True,
    fp16=False,
    gradient_checkpointing=True,
    deepspeed="ds_stage1.json",
)

trainer = SFTTrainer(
    model=model,
    args=train_args,
    train_dataset=ds,
    processing_class=tok,
    data_collator=collator,
    formatting_func=None,
)

if __name__ == "__main__":
    # Data preparation and trainer construction above already ran at module
    # level; this guard only gates the training run and the final export.
    print(f"Starting training with {len(ds):,} examples...")
    trainer.train()
    final_dir = f"{args.output_dir}/final_checkpoint"
    # Model weights and tokenizer are written to the same directory.
    trainer.save_model(final_dir)
    tok.save_pretrained(final_dir)

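To run training, use a bash launcher similar to the one below. Note that it invokes the training script as `base_train_v1.py`, while the script's header comment names it `train_gemma_pretraining.py`; adjust the filename to match your copy.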

#bash

# Cache directory for Triton-compiled kernels.
export TRITON_CACHE_DIR=/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/utils/cache/.triton
mkdir -p "$TRITON_CACHE_DIR"

# Put your W&B key here, or export it before running this script; an
# already-exported key is kept instead of being overwritten with "".
export WANDB_API_KEY="${WANDB_API_KEY:-}"

OUTPUT_DIR='/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_1'
WANDB_RUN_NAME='base-model-v1_gemma_1B_test_with_kk_only'
# mkdir -p is already a no-op when the directory exists.
mkdir -p "$OUTPUT_DIR"

# To initialise from existing weights instead of training from scratch, add:
# --model_path "/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/runs/my_experiment/checkpoint-final" \

# Effective batch = 32 per device x 8 accumulation steps x 8 GPUs
# = 2048 sequences per optimizer step.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --standalone --nproc_per_node 8 base_train_v1.py \
  --tokenizer_path "/scratch/vladimir_albrekht/projects/smollm/models/tokenizers/tok_best_version_50_000_vocab_abai_20_june" \
  --max_seq_length 2048 \
  --per_device_batch_size 32 \
  --gradient_accumulation_steps 8 \
  --learning_rate 3e-4 \
  --output_dir "${OUTPUT_DIR}" \
  --wandb_project "small_llm_SRP" \
  --wandb_run_name "${WANDB_RUN_NAME}"

Notes: checkpoint path: /scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_1/checkpoint-1172

Downloads last month
25
Safetensors
Model size
2.13B params
Tensor type
BF16
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Dataset used to train SRP-base-model-training/gemma_3_2B_base_v1_kk_only_5B-data

Collection including SRP-base-model-training/gemma_3_2B_base_v1_kk_only_5B-data