Gemma_models
Collection: Gemma models (1B, 2B, SFT, etc.) trained from scratch
4 items · Updated June 23
Base model (/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_multiling/checkpoint-900): a model trained from scratch on roughly 4.5B of data across three languages. The SFT model was then fine-tuned on Kazparc translation pairs (kk_to_en, kk_to_ru, ru_to_kk, en_to_kk).
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # set before torch initializes CUDA

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

model_path = "SRP-base-model-training/gemma_3_800M_sft_v2_translation-kazparc"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Gemma3ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# example = {"system": "Вы профессиональный переводчик. Переведите следующее предложение на қазақ язык.", "user": "<src=ru><tgt=kk>\nЗа один год с тех пор какие изменения произошли в Туркестане, какое дело доведено до конца?", "assistant": "Содан бергі бір жыл ішінде Түркістанда қандай өзгерістер болды, нендей іс тындырылды?"}
# example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nСауда-саттықта салқынқандылық басым.", "assistant": "Composure prevails in trade."}
example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nқала картасы", "assistant": "city map"}
s = example["system"]
u = example["user"]
a = example["assistant"]

# Chat-format prompt; ends with "assistant\n" to match the training template below
prompt = (
    f"<start_of_turn>system\n{s}<end_of_turn>\n"
    f"<start_of_turn>user\n{u}<end_of_turn>\n"
    f"<start_of_turn>assistant\n"
)

model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
    generation = model.generate(
        **model_inputs,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.9,
        # temperature=0.7,
        # repetition_penalty=1.2,
        eos_token_id=tokenizer.convert_tokens_to_ids("<end_of_turn>"),
        pad_token_id=tokenizer.eos_token_id,
        # min_new_tokens=5,
    )
generation = generation[0][input_len:]
decoded = tokenizer.decode(generation, skip_special_tokens=True)
print(decoded)
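The same template generalizes to the other Kazparc directions via the <src=..><tgt=..> tags. Below is a minimal sketch of a convenience wrapper, assuming the model and tokenizer loaded above; the translate helper and its defaults are illustrative, not part of the released code.

# Hypothetical helper around the chat/translation prompt format shown above.
def translate(text, src="kk", tgt="en",
              system="Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз."):
    # In the SFT data the system prompt names the target language, so adjust it per direction.
    user = f"<src={src}><tgt={tgt}>\n{text}"
    prompt = (
        f"<start_of_turn>system\n{system}<end_of_turn>\n"
        f"<start_of_turn>user\n{user}<end_of_turn>\n"
        f"<start_of_turn>assistant\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=True,
            top_p=0.9,
            eos_token_id=tokenizer.convert_tokens_to_ids("<end_of_turn>"),
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

print(translate("қала картасы", src="kk", tgt="en"))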
Main script for training
# train_gemma_sft.py 🔧
import os, math, argparse, torch
from pathlib import Path
from datasets import load_dataset, concatenate_datasets
from transformers import (AutoTokenizer, Gemma3ForCausalLM)
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM
# ─── CLI ────────────────────────────────────────────────────────────────
def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--tokenizer_path", required=True)
    p.add_argument("--model_path",
                   default="/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_1/checkpoint-300")
    p.add_argument("--data_dir", required=True,  # *.jsonl with system/user/assistant
                   help="Folder with SFT jsonl shards")
    p.add_argument("--output_dir", default="runs/gemma_sft")
    p.add_argument("--max_seq_length", type=int, default=2048)
    p.add_argument("--per_device_batch_size", type=int, default=8)
    p.add_argument("--gradient_accumulation_steps", type=int, default=4)
    p.add_argument("--learning_rate", type=float, default=2e-4)
    p.add_argument("--wandb_project", default="gemma-sft")
    p.add_argument("--wandb_run_name", default=None)
    return p.parse_args()
args = parse_args()
os.environ["WANDB_PROJECT"] = args.wandb_project
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# ─── tokenizer / model ─────────────────────────────────────────────────
tok = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=True)
for t in ["<start_of_turn>", "<end_of_turn>"]:
    if t not in tok.get_vocab():
        tok.add_special_tokens({"additional_special_tokens": [t]})

model = Gemma3ForCausalLM.from_pretrained(
    args.model_path,
    torch_dtype=torch.bfloat16,
    _attn_implementation="eager",
)
model.resize_token_embeddings(len(tok))  # in case we added tags
# ─── dataset loading ──────────────────────────────────────────────────
data_dir = Path(args.data_dir)
jsonl_files = sorted(data_dir.glob("*.jsonl"))
if not jsonl_files:
    raise ValueError("no jsonl found")
print(f"→ Loading {len(jsonl_files)} shards")
dsets = [load_dataset("json", data_files=str(f), split="train")
         for f in jsonl_files]
raw_ds = concatenate_datasets(dsets)
# build chat template + rough length filter
MAX_LEN = args.max_seq_length
def build_and_filter_batch(ex):
    texts = []
    for s, u, a in zip(ex["system"], ex["user"], ex["assistant"]):
        if (len(s) + len(u) + len(a)) > MAX_LEN * 4:  # ≈ char filter
            continue
        t = (f"<start_of_turn>system\n{s}<end_of_turn>\n"
             f"<start_of_turn>user\n{u}<end_of_turn>\n"
             f"<start_of_turn>assistant\n{a}<end_of_turn>{tok.eos_token}")
        texts.append(t)
    return {"text": texts}
cpu = os.cpu_count()
ds = raw_ds.map(build_and_filter_batch,
                batched=True, batch_size=1000, num_proc=cpu,
                remove_columns=raw_ds.column_names)
ds = ds.shuffle(seed=42)
# ─── collator: compute loss only on the assistant response ────────────
collator = DataCollatorForCompletionOnlyLM(
    tokenizer=tok,
    instruction_template="<start_of_turn>user\n",
    response_template="<start_of_turn>assistant\n",
    mlm=False,
)
# ─── training args ─────────────────────────────────────────────────────
train_cfg = SFTConfig(
    output_dir=args.output_dir,
    run_name=args.wandb_run_name,
    max_seq_length=args.max_seq_length,
    gradient_checkpointing=True,
    packing=False,
    per_device_train_batch_size=args.per_device_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    bf16=True,
    warmup_ratio=0.03,
    weight_decay=0.01,
    do_train=True,
    group_by_length=True,
    lr_scheduler_type="cosine",
    logging_steps=1,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=15,
    deepspeed="../train_trl/ds_stage1.json",
    dataloader_num_workers=8,
    dataset_num_proc=cpu,
)
trainer = SFTTrainer(
    model=model,
    args=train_cfg,
    train_dataset=ds,
    data_collator=collator,
    processing_class=tok,
)
if __name__ == "__main__":
    print(f"🚀 Start SFT: {len(ds):,} chat samples")
    trainer.train()
    trainer.save_model(f"{args.output_dir}/checkpoint-final")
    tok.save_pretrained(f"{args.output_dir}/checkpoint-final")
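Each *.jsonl shard in --data_dir is expected to hold one record per line with system/user/assistant fields, the same shape as the commented examples in the inference snippet above. A minimal sketch of writing such a record (the file name and record content are illustrative):

import json

record = {  # same fields the training script reads
    "system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.",
    "user": "<src=kk><tgt=en>\nқала картасы",
    "assistant": "city map",
}
with open("train_shard_000.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")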
To run training, use a bash script similar to the following:
#!/bin/bash
export TRITON_CACHE_DIR=/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/utils/cache/.triton
mkdir -p "$TRITON_CACHE_DIR"
export WANDB_API_KEY=""
OUTPUT_DIR='/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_sft_with_base_model_v1_2'
WANDB_RUN_NAME='sft_translation_on_test_2_sft_with_base_model_v1_2'
if [ ! -d "$OUTPUT_DIR" ]; then
    mkdir -p "$OUTPUT_DIR"
fi
# --model_path "/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/runs/my_experiment/checkpoint-final" \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --standalone --nproc_per_node 8 test_sft_train.py \
--tokenizer_path "/scratch/vladimir_albrekht/projects/smollm/models/tokenizers/tok_best_version_50_000_vocab_abai_20_june" \
--model_path "/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_multiling/checkpoint-900" \
--data_dir "/scratch/vladimir_albrekht/projects/smollm/data/sft/kazparc/jsonl/train" \
--max_seq_length 2048 \
--per_device_batch_size 32 \
--gradient_accumulation_steps 8 \
--learning_rate 4e-5 \
--output_dir ${OUTPUT_DIR} \
--wandb_project "small_llm_SRP" \
--wandb_run_name ${WANDB_RUN_NAME}
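With these settings the effective global batch size is 8 GPUs × 32 sequences per device × 8 gradient-accumulation steps = 2048 sequences per optimizer step.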
Note: the resulting checkpoint from this run is saved at /scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_sft_with_base_model_v1_2/checkpoint-2178.