Base model v2: gemma_3_800M_base_v2_multilingual_10B_data

June 23

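Base model trained on a 10B Kazakh (kk), English (en), and Russian (ru) corpus, mixed roughly 50 % kk / 30 % ru / 20 % en by document count (see the training script below).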

Inference example

import os

# Must be set before CUDA is initialised. The variable CUDA reads is
# CUDA_VISIBLE_DEVICES (plural); a misspelled name is silently ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

# Load the trained model; device_map="auto" places it on the visible GPUs.
model_path = "SRP-base-model-training/gemma_3_800M_base_v2_multilingual_10B_data"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Gemma3ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Alternative examples (ru->kk, kk->en); uncomment one to try it instead.
# example = {"system": "Вы профессиональный переводчик. Переведите следующее предложение на қазақ язык.", "user": "<src=ru><tgt=kk>\nЗа один год с тех пор какие изменения произошли в Туркестане, какое дело доведено до конца?", "assistant": "Содан бергі бір жыл ішінде Түркістанда қандай өзгерістер болды, нендей іс тындырылды?"}
# example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nСауда-саттықта салқынқандылық басым.", "assistant": "Composure prevails in trade."}
example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nқала картасы", "assistant": "city map"}
system_msg = example["system"]
user_msg = example["user"]
reference = example["assistant"]  # expected output, printed for comparison

# Chat-style prompt.
# NOTE(review): unlike the system/user headers, the assistant header has no
# trailing "\n" -- confirm this matches the format the model was trained on.
prompt = (
    f"<start_of_turn>system\n{system_msg}<end_of_turn>\n"
    f"<start_of_turn>user\n{user_msg}<end_of_turn>\n"
    f"<start_of_turn>assistant"
)

model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

# Stop on the end-of-turn marker *and* on the plain EOS token: the training
# script appends tokenizer.eos_token to every document, so a base model may
# end its answer with EOS instead of <end_of_turn>. convert_tokens_to_ids
# falls back to the unk id when a token is missing from the vocabulary, so
# that case is filtered out instead of being used as a stop token.
end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
if end_of_turn_id == tokenizer.unk_token_id:
    end_of_turn_id = None
stop_ids = [i for i in (end_of_turn_id, tokenizer.eos_token_id) if i is not None]

with torch.inference_mode():
    output_ids = model.generate(
        **model_inputs,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.9,
        # temperature=0.7,
        # repetition_penalty=1.2,
        # min_new_tokens=5,
        eos_token_id=stop_ids,
        pad_token_id=tokenizer.eos_token_id,
    )

# Keep only the newly generated tokens (drop the echoed prompt).
generated = output_ids[0][input_len:]

decoded = tokenizer.decode(generated, skip_special_tokens=True)
print(decoded)
print(f"reference: {reference}")

Train

Main training script:

# gemma_pretrain_mix_cli.py  – balance 50 % KK, 30 % RU, 20 % EN

import os, math, json, argparse
from pathlib import Path
from datasets import (load_dataset, concatenate_datasets,
                      disable_caching)
from transformers import (AutoTokenizer, Gemma3TextConfig,
                          Gemma3ForCausalLM,
                          DataCollatorForLanguageModeling)
from trl import SFTTrainer, SFTConfig

disable_caching()

# ────────── CLI ──────────
parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer_path", required=True)
# help text: "paths to the meta_*.json files, in kk ru en order"
parser.add_argument("--meta_files", nargs=3, required=True,
                    metavar=("META_KK", "META_RU", "META_EN"),
                    help="пути к meta_*.json в порядке kk ru en")
parser.add_argument("--output_dir", default="runs/gemma_mix_50_30_20")
# Required: there is no working from-scratch initialisation (see the TODO
# below). Omitting it used to crash with NameError on an undefined `cfg`,
# and only after the tokenizer had already been loaded.
parser.add_argument("--model_path", required=True,
                    help="checkpoint to continue pre-training from")
parser.add_argument("--max_seq_length", type=int, default=2048)
parser.add_argument("--per_device_batch_size", type=int, default=32)
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
parser.add_argument("--learning_rate", type=float, default=3e-4)
parser.add_argument("--wandb_project", default="gemma-pretrain")
parser.add_argument("--wandb_run_name")
args = parser.parse_args()

cpu = os.cpu_count()  # worker count for dataset preprocessing
os.environ["WANDB_PROJECT"]          = args.wandb_project
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# ────────── Tokenizer / Model ──────────
tok = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=True)

model = Gemma3ForCausalLM.from_pretrained(
    args.model_path, torch_dtype="bfloat16", _attn_implementation="eager")

# TODO: from-scratch initialisation. The config draft below was marked
# "WRONG" by the author and is kept only as a starting point. Once it is
# fixed, make --model_path optional again and build the model with:
#     model = Gemma3ForCausalLM(cfg)
#     model.resize_token_embeddings(len(tok))
# cfg = Gemma3TextConfig(
#     vocab_size=len(tok),
#     bos_token_id=tok.bos_token_id, eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id,
#     hidden_size=2304, num_hidden_layers=26, num_attention_heads=4, head_dim=256,
#     intermediate_size=9216, max_position_embeddings=32_768,
#     torch_dtype="bfloat16", _attn_implementation="eager")

# ────────── Load helper ──────────
def load_meta(path: str):
    """Load and concatenate every dataset file listed in a meta_*.json file.

    The meta file is a JSON object whose values each carry a ``"path"`` key
    naming a JSON/JSONL file. Other keys (e.g. ``examples``, ``tokens``) are
    ignored here.

    Args:
        path: path to the meta JSON file.

    Returns:
        The ``train`` splits of all listed files, concatenated in the
        meta file's key order.

    Raises:
        ValueError: if the meta file lists no datasets.
    """
    # Context manager closes the handle; explicit UTF-8 avoids locale surprises.
    with open(path, encoding="utf-8") as fh:
        meta = json.load(fh)
    if not meta:
        # Name the offending file instead of failing deeper inside datasets.
        raise ValueError(f"meta file {path!r} lists no datasets")
    return concatenate_datasets(
        [load_dataset("json", data_files=entry["path"], split="train")
         for entry in meta.values()]
    )

# Load the three language pools; --meta_files is given in kk, ru, en order.
kk_ds, ru_ds, en_ds = map(load_meta, args.meta_files)
print(f"Raw rows — KK={len(kk_ds):,}, RU={len(ru_ds):,}, EN={len(en_ds):,}")

# ────────── Target sizes 50 / 30 / 20 ──────────
# Kazakh is kept whole and defines the 50 % share; the RU/EN row targets
# are derived from the implied total.
target_total = int(len(kk_ds) / 0.50)
need_ru, need_en = (int(target_total * share) for share in (0.30, 0.20))

def resize(ds, need, seed=42):
    """Return a shuffled dataset with exactly ``need`` rows drawn from ``ds``.

    When ``ds`` has at least ``need`` rows, it is shuffled and the first
    ``need`` rows are kept (down-sampling without replacement). Otherwise
    ``ds`` is repeated whole enough times, shuffled, and trimmed to ``need``
    (up-sampling).

    Args:
        ds: a ``datasets.Dataset`` (anything supporting ``len``,
            ``shuffle(seed=...)`` and ``select``).
        need: target number of rows.
        seed: shuffle seed (defaults to 42, as before).

    Raises:
        ValueError: if ``ds`` is empty but ``need`` > 0. Previously this
            surfaced as a ZeroDivisionError.
    """
    size = len(ds)
    if size >= need:                          # down-sample
        return ds.shuffle(seed=seed).select(range(need))
    if size == 0:
        raise ValueError(f"cannot up-sample an empty dataset to {need} rows")
    reps = need // size + 1                   # up-sample: enough copies to cover `need`
    big = concatenate_datasets([ds] * reps).shuffle(seed=seed)
    return big.select(range(need))

# Bring RU and EN to their target row counts; KK is left untouched.
ru_ds, en_ds = resize(ru_ds, need_ru), resize(en_ds, need_en)
print(f"Balanced rows — KK={len(kk_ds):,}, RU={len(ru_ds):,}, EN={len(en_ds):,}")

# ────────── Merge & preprocess ──────────
# A single global shuffle interleaves the languages instead of leaving them
# in contiguous blocks.
ds = (
    concatenate_datasets([kk_ds, ru_ds, en_ds])
    .shuffle(seed=42)
)

def add_bos_eos(ex):
    """Wrap one example's ``text`` field in the tokenizer's BOS and EOS tokens.

    NOTE(review): if the tokenizer also inserts BOS itself when the trainer
    later tokenizes ``text`` (Gemma tokenizers are usually configured with
    add_bos_token=True -- verify for this tokenizer), each document may end
    up with two BOS tokens. Confirm against a tokenized sample.
    """
    wrapped = "{}{}{}".format(tok.bos_token, ex["text"], tok.eos_token)
    return {"text": wrapped}
ds = ds.map(add_bos_eos, num_proc=cpu)

# ────────── Training params ──────────
# Effective (global) batch = per-device batch x gradient accumulation x processes.
world  = int(os.getenv("WORLD_SIZE", 1))   # set by torchrun; 1 for a single process
# The CLI option is --gradient_accumulation_steps; `args.grad_acc` does not
# exist and raised AttributeError here.
eff_bs = args.per_device_batch_size * args.gradient_accumulation_steps * world
# NOTE(review): with packing=True the trainer consumes packed sequences of
# max_seq_length tokens rather than raw rows, so this step count only
# approximates one pass over the data -- confirm the intended schedule.
max_st = math.ceil(len(ds) / eff_bs)
print(f"Dataset={len(ds):,}  eff_batch={eff_bs}  max_steps={max_st}")

# Causal-LM (non-masked) collator.
collator = DataCollatorForLanguageModeling(tok, mlm=False)

# NOTE(review): newer TRL releases rename `max_seq_length` to `max_length`;
# check against the installed TRL version.
cfg_t = SFTConfig(
    # --- output, checkpoints, logging ---
    output_dir=args.output_dir,
    run_name=args.wandb_run_name,
    report_to="wandb",
    logging_steps=1,
    save_steps=200,
    save_total_limit=20,
    # --- data ---
    dataset_text_field="text",
    max_seq_length=args.max_seq_length,
    packing=True,
    dataset_num_proc=cpu,
    dataloader_num_workers=8,
    # --- batching ---
    per_device_train_batch_size=args.per_device_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    # --- optimisation ---
    optim="paged_adamw_8bit",
    learning_rate=args.learning_rate,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    max_grad_norm=2.0,
    max_steps=max_st,
    # --- precision / distributed (DeepSpeed config path is relative to the cwd) ---
    bf16=True,
    deepspeed="ds_stage1.json",
)

trainer = SFTTrainer(
    model=model,
    args=cfg_t,
    train_dataset=ds,
    data_collator=collator,
    processing_class=tok,
    formatting_func=None,
)

if __name__ == "__main__":
    final_dir = f"{args.output_dir}/checkpoint-final"
    print("🚀 Start pre-training 50/30/20")
    trainer.train()
    # Save model weights and tokenizer into the same final folder.
    trainer.save_model(final_dir)
    tok.save_pretrained(final_dir)

To launch training, use a bash script similar to the following (it invokes the training script above as `base_train_v2_multi.py`):

#bash

# Triton kernel cache location.
export TRITON_CACHE_DIR=/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/utils/cache/.triton
mkdir -p "$TRITON_CACHE_DIR"

# Fill in your Weights & Biases API key.
export WANDB_API_KEY=""

OUTPUT_DIR='/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_multiling'
WANDB_RUN_NAME='base-model-v1_gemma_1B_test_v2_with_kk_en_ru'
# mkdir -p is already a no-op when the directory exists.
mkdir -p "$OUTPUT_DIR"

# The training script has no working from-scratch initialisation, so pass a
# starting checkpoint with --model_path, e.g.:
# --model_path "/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/runs/my_experiment/checkpoint-final" \

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --standalone --nproc_per_node 8 base_train_v2_multi.py \
  --tokenizer_path "/scratch/vladimir_albrekht/projects/smollm/models/tokenizers/tok_best_version_50_000_vocab_abai_20_june" \
  --max_seq_length 2048 \
  --meta_files \
      /scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/meta_kk.json \
      /scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/meta_ru.json \
      /scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/meta_en.json \
  --per_device_batch_size 32 \
  --gradient_accumulation_steps 8 \
  --learning_rate 3e-4 \
  --output_dir "${OUTPUT_DIR}" \
  --wandb_project "small_llm_SRP" \
  --wandb_run_name "${WANDB_RUN_NAME}"

Each meta_*.json file has the following format (the training script reads only the `path` field of each entry):

{
  "train_en_news_cleaned_v2_splited_processed.jsonl": {
    "path": "/scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/en_data/train.jsonl",
    "examples": 268890,
    "tokens": 92970273
  },
  "train_en_news_cleaned_v2_splited_processed_2.jsonl": {
    "path": "/scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/en_data/train_2.jsonl",
    "examples": 268123,
    "tokens": 64523423
  }
}

Note: source checkpoint path: /scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_multiling/checkpoint-1978

Downloads last month
26
Safetensors
Model size
859M params
Tensor type
BF16
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for SRP-base-model-training/gemma_3_800M_base_v2_multilingual_10B_data

Quantizations
1 model

Dataset used to train SRP-base-model-training/gemma_3_800M_base_v2_multilingual_10B_data

Collection including SRP-base-model-training/gemma_3_800M_base_v2_multilingual_10B_data