Gemma_models collection
Gemma models 1B, 2B, SFT etc. trained from scratch
Base model trained on 10B kk,en,ru data.
import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# Load the trained model
model_path = "SRP-base-model-training/gemma_3_800M_base_v2_multilingual_10B_data"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Gemma3ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# example = {"system": "Вы профессиональный переводчик. Переведите следующее предложение на қазақ язык.", "user": "<src=ru><tgt=kk>\nЗа один год с тех пор какие изменения произошли в Туркестане, какое дело доведено до конца?", "assistant": "Содан бергі бір жыл ішінде Түркістанда қандай өзгерістер болды, нендей іс тындырылды?"}
# example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nСауда-саттықта салқынқандылық басым.", "assistant": "Composure prevails in trade."}
example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nқала картасы", "assistant": "city map"}
s = example["system"]
u = example["user"]
a = example["assistant"]
tok = tokenizer
# Prompt in chat format
prompt = (
    f"<start_of_turn>system\n{s}<end_of_turn>\n"
    f"<start_of_turn>user\n{u}<end_of_turn>\n"
    f"<start_of_turn>assistant"
)
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
    generation = model.generate(
        **model_inputs,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.9,
        # temperature=0.7,
        # repetition_penalty=1.2,
        eos_token_id=tok.convert_tokens_to_ids("<end_of_turn>"),
        pad_token_id=tok.eos_token_id,
        # min_new_tokens=5,
    )
generation = generation[0][input_len:]
decoded = tokenizer.decode(generation, skip_special_tokens=True)
print(decoded)
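For repeated use, the prompt construction and generation above can be wrapped in a small helper. Below is a minimal sketch (the translate function is illustrative, not part of the original scripts) that reuses the model and tokenizer already loaded:

def translate(system: str, user: str, max_new_tokens: int = 64) -> str:
    # Build the same chat-style prompt as above and decode only the newly generated tokens.
    prompt = (
        f"<start_of_turn>system\n{system}<end_of_turn>\n"
        f"<start_of_turn>user\n{user}<end_of_turn>\n"
        f"<start_of_turn>assistant"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            eos_token_id=tokenizer.convert_tokens_to_ids("<end_of_turn>"),
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(out[0][input_len:], skip_special_tokens=True)

# e.g. translate(example["system"], example["user"]) should return an English
# translation such as "city map" for the sample above.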
Main script for training
# gemma_pretrain_mix_cli.py – balance 50 % KK, 30 % RU, 20 % EN
import os, math, json, argparse
from pathlib import Path
from datasets import (load_dataset, concatenate_datasets,
                      disable_caching)
from transformers import (AutoTokenizer, Gemma3TextConfig,
                          Gemma3ForCausalLM,
                          DataCollatorForLanguageModeling)
from trl import SFTTrainer, SFTConfig
disable_caching()
# ────────── CLI ──────────
parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer_path", required=True)
parser.add_argument("--meta_files", nargs=3, required=True,
metavar=("META_KK", "META_RU", "META_EN"),
help="пути к meta_*.json в порядке kk ru en")
parser.add_argument("--output_dir", default="runs/gemma_mix_50_30_20")
parser.add_argument("--model_path")
parser.add_argument("--max_seq_length", type=int, default=2048)
parser.add_argument("--per_device_batch_size", type=int, default=32)
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
parser.add_argument("--learning_rate", type=float, default=3e-4)
parser.add_argument("--wandb_project", default="gemma-pretrain")
parser.add_argument("--wandb_run_name")
args = parser.parse_args()
cpu = os.cpu_count()
os.environ["WANDB_PROJECT"] = args.wandb_project
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# ────────── Tokenizer / Model ──────────
tok = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=True)
if args.model_path:
    model = Gemma3ForCausalLM.from_pretrained(
        args.model_path, torch_dtype="bfloat16", _attn_implementation="eager")
else:
    # TODO: verify these architecture hyperparameters before training from scratch
    cfg = Gemma3TextConfig(
        vocab_size=len(tok),
        bos_token_id=tok.bos_token_id, eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id,
        hidden_size=2304, num_hidden_layers=26, num_attention_heads=4, head_dim=256,
        intermediate_size=9216, max_position_embeddings=32_768,
        torch_dtype="bfloat16", _attn_implementation="eager")
    model = Gemma3ForCausalLM(cfg)
model.resize_token_embeddings(len(tok))
# ────────── Load helper ──────────
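# Each meta_*.json maps a source file name to {"path", "examples", "tokens"}
# (see the example at the bottom of this page); only "path" is used here.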
def load_meta(path: str):
    meta = json.load(open(path))
    return concatenate_datasets(
        [load_dataset("json", data_files=i["path"], split="train")
         for i in meta.values()]
    )
kk_ds, ru_ds, en_ds = [load_meta(p) for p in args.meta_files]
print(f"Raw rows — KK={len(kk_ds):,}, RU={len(ru_ds):,}, EN={len(en_ds):,}")
# ────────── Target sizes 50 / 30 / 20 ──────────
target_total = int(len(kk_ds) / 0.50) # kk = 50 %
need_ru = int(target_total * 0.30)
need_en = int(target_total * 0.20)
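# Example: with 1,000,000 KK rows, target_total = 2,000,000,
# so RU is resized to 600,000 rows and EN to 400,000 rows below.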
def resize(ds, need):
    if len(ds) >= need:                       # down-sample
        return ds.shuffle(seed=42).select(range(need))
    reps = need // len(ds) + 1                # up-sample
    big = concatenate_datasets([ds] * reps).shuffle(seed=42)
    return big.select(range(need))
ru_ds = resize(ru_ds, need_ru)
en_ds = resize(en_ds, need_en)
print(f"Balanced rows — KK={len(kk_ds):,}, RU={len(ru_ds):,}, EN={len(en_ds):,}")
# ────────── Merge & preprocess ──────────
ds = concatenate_datasets([kk_ds, ru_ds, en_ds]).shuffle(seed=42)
def add_bos_eos(ex):
    return {"text": f"{tok.bos_token}{ex['text']}{tok.eos_token}"}
ds = ds.map(add_bos_eos, num_proc=cpu)
# ────────── Training params ──────────
world = int(os.getenv("WORLD_SIZE", 1))
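# Effective batch size = per-device batch × gradient-accumulation steps × number
# of processes (WORLD_SIZE is set by torchrun).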
eff_bs = args.per_device_batch_size * args.gradient_accumulation_steps * world
max_st = math.ceil(len(ds) / eff_bs)
print(f"Dataset={len(ds):,} eff_batch={eff_bs} max_steps={max_st}")
collator = DataCollatorForLanguageModeling(tok, mlm=False)
cfg_t = SFTConfig(
    output_dir=args.output_dir,
    max_seq_length=args.max_seq_length,
    packing=True, bf16=True,
    per_device_train_batch_size=args.per_device_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    warmup_ratio=0.05,
    max_grad_norm=2.0,
    max_steps=max_st,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    save_steps=200, save_total_limit=20,
    logging_steps=1,
    deepspeed="ds_stage1.json",
    run_name=args.wandb_run_name,
    report_to="wandb",
    dataloader_num_workers=8,
    dataset_text_field="text",
    dataset_num_proc=cpu,
)
trainer = SFTTrainer(model=model, args=cfg_t,
                     train_dataset=ds, data_collator=collator,
                     processing_class=tok, formatting_func=None)
if __name__ == "__main__":
    print("🚀 Start pre-training 50/30/20")
    trainer.train()
    trainer.save_model(f"{args.output_dir}/checkpoint-final")
    tok.save_pretrained(f"{args.output_dir}/checkpoint-final")
To run training, use a bash script similar to the following:
#!/bin/bash
export TRITON_CACHE_DIR=/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/utils/cache/.triton
mkdir -p "$TRITON_CACHE_DIR"
export WANDB_API_KEY=""
OUTPUT_DIR='/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_multiling'
WANDB_RUN_NAME='base-model-v1_gemma_1B_test_v2_with_kk_en_ru'
if [ ! -d "$OUTPUT_DIR" ]; then
mkdir -p "$OUTPUT_DIR"
fi
# --model_path "/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/runs/my_experiment/checkpoint-final" \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --standalone --nproc_per_node 8 base_train_v2_multi.py \
--tokenizer_path "/scratch/vladimir_albrekht/projects/smollm/models/tokenizers/tok_best_version_50_000_vocab_abai_20_june" \
--max_seq_length 2048 \
--meta_files \
/scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/meta_kk.json \
/scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/meta_ru.json \
/scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/meta_en.json \
--per_device_batch_size 32 \
--gradient_accumulation_steps 8 \
--learning_rate 3e-4 \
--output_dir ${OUTPUT_DIR} \
--wandb_project "small_llm_SRP" \
--wandb_run_name ${WANDB_RUN_NAME}
The meta files have the following format:
"train_en_news_cleaned_v2_splited_processed.jsonl": {
"path": "/scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/en_data/train.jsonl",
"examples": 268890,
"tokens": 92970273
},
"train_en_news_cleaned_v2_splited_processed_2.jsonl": {
"path": "/scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/en_data/train_2.jsonl",
"examples": 268123,
"tokens": 64523423
}
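A meta file like the one above can be produced by counting examples and tokens per .jsonl file. Below is a minimal sketch, assuming each .jsonl line has a "text" field; the data directory, tokenizer path, and output file name are placeholders:

# build_meta.py – illustrative sketch, not part of the original pipeline
import json
from pathlib import Path
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("/path/to/tokenizer", use_fast=True)

meta = {}
for fp in sorted(Path("en_data").glob("*.jsonl")):
    examples, tokens = 0, 0
    with open(fp) as f:
        for line in f:
            text = json.loads(line)["text"]
            examples += 1
            tokens += len(tok(text, add_special_tokens=False)["input_ids"])
    meta[fp.name] = {"path": str(fp.resolve()), "examples": examples, "tokens": tokens}

with open("meta_en.json", "w") as f:
    json.dump(meta, f, ensure_ascii=False, indent=2)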
Notes: trained checkpoint available at /scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_multiling/checkpoint-1978