# train_medium.py
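# Trains a RoBERTa-style masked language model from scratch with a live
# Streamlit dashboard for loss, gradient norm, and perplexity.
# Launch with: streamlit run train_medium.py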
import os
import math
import random
import torch
import pandas as pd
import numpy as np
import streamlit as st
import plotly.graph_objects as go
from transformers import (
    RobertaConfig, RobertaForMaskedLM, Trainer, TrainingArguments,
    PreTrainedTokenizerFast, DataCollatorForLanguageModeling, TrainerCallback
)
# Import Value from datasets alongside others
from datasets import load_dataset, Features, Sequence, Value
# --- Streamlit setup ---
st.set_page_config(layout="wide")
# --- Constants ---
TOKENIZER_DIR = "tokenizer" # Ensure this matches the one used in preprocessing
DATA_PATH = "training_data.jsonl" # Ensure this is the output from sentence_aware_processor.py
OUTPUT_DIR = "./checkpoints_medium"
VOCAB_SIZE = 32000
MAX_LEN = 512
BATCH_SIZE = 64
EPOCHS = 50
GRAD_ACC = 8
LEARNING_RATE = 1e-3
MLM_PROB = 0.15
SEED = 42
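# Effective batch size per optimizer step: BATCH_SIZE * GRAD_ACC = 64 * 8 = 512
# sequences (per device; multiply by GPU count when running multi-GPU).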
# --- Seed ---
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed(SEED)
# --- Tokenizer ---
if not os.path.exists(os.path.join(TOKENIZER_DIR, "tokenizer.json")):
    st.error(f"Tokenizer not found in {TOKENIZER_DIR}")
    st.stop()
try:
    tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_DIR)
    tokenizer.model_max_length = MAX_LEN
except Exception as e:
    st.error(f"Error loading tokenizer from {TOKENIZER_DIR}: {e}")
    st.stop()
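# Optional sanity check (a minimal addition, not in the original script):
# DataCollatorForLanguageModeling with mlm=True requires a mask token, and the
# model config below reads the pad/cls/sep ids, so fail fast if any are missing.
for _tok in ("mask_token", "pad_token", "cls_token", "sep_token"):
    if getattr(tokenizer, _tok, None) is None:
        st.error(f"Tokenizer is missing required special token: {_tok}")
        st.stop()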
# --- Dataset ---
# !!! MODIFIED: Updated Features definition to match the JSONL structure !!!
features = Features({
    'id': Value(dtype='int64'),                   # Added id field
    'input_ids': Sequence(Value(dtype='int32')),
    'source': Value(dtype='string')               # Added source field
})
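# Each JSONL line is expected to match these features; an illustrative record:
# {"id": 0, "input_ids": [2, 8734, 415, 3], "source": "example_corpus"}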
# (Error handling remains)
try:
    # Load the dataset using the updated features
    dataset = load_dataset("json", data_files=DATA_PATH, features=features, split="train")
    #st.success(f"Loaded dataset from {DATA_PATH} with columns: {dataset.column_names}")
except Exception as e:
    st.error(f"Failed to load dataset from {DATA_PATH}: {e}")
    st.info(f"Ensure '{DATA_PATH}' exists and matches the features: {features}")
    st.stop()
# --- Add Attention Mask ---
# This function remains the same, as it only needs 'input_ids'
if 'attention_mask' not in dataset.column_names:
    def add_attention_mask(example):
        # The length is derived from the 'input_ids' field
        example["attention_mask"] = [1] * len(example["input_ids"])
        return example
    # os.cpu_count() can return None, so guard before dividing
    dataset = dataset.map(add_attention_mask, num_proc=max(1, (os.cpu_count() or 2) // 2))
    #st.info("Added 'attention_mask' column.")
# --- Collator ---
# The Trainer's default remove_unused_columns=True drops columns the model's
# forward() doesn't accept ('id', 'source'), so the collator only ever sees
# 'input_ids' and 'attention_mask'.
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=MLM_PROB
)
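# With mlm_probability=0.15 the collator picks ~15% of tokens as MLM targets;
# of those, 80% are replaced with the mask token, 10% with a random token, and
# 10% are left unchanged (the standard BERT masking recipe).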
# --- Model ---
# Model definition remains the same
config = RobertaConfig(
    vocab_size=VOCAB_SIZE,
    hidden_size=512,
    num_hidden_layers=8,
    num_attention_heads=16,
    intermediate_size=2048,
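    # RoBERTa offsets position ids by padding_idx + 1, so the embedding table
    # needs MAX_LEN + 2 rows to cover a full 512-token sequence.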
    max_position_embeddings=MAX_LEN + 2,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.cls_token_id,
    eos_token_id=tokenizer.sep_token_id,
)
model = RobertaForMaskedLM(config=config)
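# Back-of-envelope size: ~16.4M embedding parameters (32k vocab x 512 dims)
# plus ~3.2M per transformer layer x 8 layers, roughly 42M parameters overall.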
# --- UI State ---
# UI setup remains the same
log = {"step": [], "loss": [], "grad_norm": [], "perplexity": []}
progress = st.empty()
col1, col2 = st.columns(2)
with col1:
    chart1_placeholder = st.empty()
    chart2_placeholder = st.empty()
with col2:
    chart3_placeholder = st.empty()
    chart4_placeholder = st.empty()
# --- Plotting Functions (Unchanged) ---
def get_safe_range(values, pad_percent=0.1):
    values = pd.Series(values).dropna()
    if values.empty:
        return (0, 1)
    if len(values) == 1:
        return (values.iloc[0] * 0.9, values.iloc[0] * 1.1)
    numeric_values = pd.to_numeric(values, errors='coerce').dropna()
    if numeric_values.empty:
        return (0, 1)
    # Clip to the 2nd-95th percentile band so outliers don't flatten the chart
    low, high = np.percentile(numeric_values, [2, 95])
    pad = abs(high - low) * pad_percent
    return max(0, low - pad), high + pad
def forecast_plot(df):
    if len(df) < 10:
        return go.Figure(layout_title_text="Loss Forecast (Need more data)")
    x = pd.to_numeric(df["step"], errors='coerce').dropna().values
    y = pd.to_numeric(df["loss"], errors='coerce').dropna().values
    if len(x) < 2 or len(y) < 2 or len(x) != len(y):
        return go.Figure(layout_title_text="Loss Forecast (Data error)")
    # Extrapolate 50% beyond the last logged step
    forecast_x = np.linspace(x[0], x[-1] * 1.5, 300)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=y, mode='lines', name="Actual Loss"))
    # Fit linear trends over the most recent 1%, 10%, and 50% of steps
    for percent, color in [(1, 'orange'), (10, 'green'), (50, 'red')]:
        n = max(5, int(len(x) * percent / 100))
        if len(x) >= n and n >= 2:
            sub_x, sub_y = x[-n:], y[-n:]
            try:
                valid_indices = ~np.isnan(sub_x) & ~np.isnan(sub_y)
                if np.sum(valid_indices) >= 2:
                    m, b = np.polyfit(sub_x[valid_indices], sub_y[valid_indices], 1)
                    y_fit = m * forecast_x + b
                    fig.add_trace(go.Scatter(x=forecast_x, y=y_fit, name=f"{percent}% Trend", line=dict(dash='dot', color=color)))
            except (np.linalg.LinAlgError, ValueError) as e:
                print(f"Warning: Could not fit trend for {percent}%: {e}")
    fig.update_layout(title="Loss Forecast", xaxis_title="Step", yaxis_title="Loss", legend_title_text='Trend % (Recent)')
    return fig
# --- Streamlit Callback (Unchanged) ---
class StreamlitCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        # Only chart from the main process, and only on logs that carry a loss
        if not state.is_world_process_zero or logs is None or "loss" not in logs:
            return
        step = state.global_step
        loss = float(logs["loss"]) if isinstance(logs["loss"], (int, float)) else None
        grad = float(logs.get("grad_norm")) if isinstance(logs.get("grad_norm"), (int, float)) else None
        if loss is None:
            return
        # Cap the exponent: math.exp overflows just above 709
        ppl = math.exp(min(loss, 700))
        log["step"].append(step)
        log["loss"].append(loss)
        log["grad_norm"].append(grad)
        log["perplexity"].append(ppl)
        df = pd.DataFrame(log).dropna(subset=['step', 'loss'])
        if df.empty:
            return
        try:
            r1 = get_safe_range(df["loss"])
            r2 = get_safe_range(df["grad_norm"])
            r3 = get_safe_range(df["perplexity"])
            fig1 = go.Figure().add_trace(go.Scatter(x=df["step"], y=df["loss"], mode='lines'))
            grad_norm_data = df["grad_norm"].dropna()
            if not grad_norm_data.empty:
                fig2 = go.Figure().add_trace(go.Scatter(x=df.loc[grad_norm_data.index, "step"], y=grad_norm_data, mode='lines'))
            else:
                fig2 = go.Figure()
            fig3 = go.Figure().add_trace(go.Scatter(x=df["step"], y=df["perplexity"], mode='lines'))
            fig1.update_layout(title="Loss", yaxis_range=r1, xaxis_title="Step", yaxis_title="Loss")
            fig2.update_layout(title="Gradient Norm", yaxis_range=r2, xaxis_title="Step", yaxis_title="Grad Norm")
            fig3.update_layout(title="Perplexity", yaxis_range=r3, xaxis_title="Step", yaxis_title="Perplexity")
            fig4 = forecast_plot(df)
            # Unique keys per step keep Streamlit from flagging duplicate charts
            chart1_placeholder.plotly_chart(fig1, use_container_width=True, key=f"loss_chart_{step}")
            chart2_placeholder.plotly_chart(fig2, use_container_width=True, key=f"grad_norm_chart_{step}")
            chart3_placeholder.plotly_chart(fig3, use_container_width=True, key=f"perplexity_chart_{step}")
            chart4_placeholder.plotly_chart(fig4, use_container_width=True, key=f"forecast_chart_{step}")
        except Exception as e:
            print(f"Error updating Streamlit charts at step {step}: {e}")
# --- Training args ---
# Training args remain the same
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACC,
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type='linear',
    warmup_ratio=0.1,
    weight_decay=0.01,
    max_grad_norm=1.0,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=100,
    logging_strategy="steps",
    logging_steps=10,
    dataloader_num_workers=4,
    # Prefer bf16 where the GPU supports it, fall back to fp16 on other CUDA
    # devices, and run full fp32 on CPU. Checking is_available() first avoids
    # calling is_bf16_supported() on machines without CUDA.
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
    seed=SEED,
    report_to=["none"],
    # !! Remember to handle checkpoints appropriately for a fresh run !!
    resume_from_checkpoint=None,  # None = clean run; set a checkpoint path to resume
)
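# With save_steps=1000 and save_total_limit=100, up to 100 checkpoints stay on
# disk; at ~42M parameters, each checkpoint (weights plus optimizer state) runs
# to several hundred MB, so budget storage accordingly.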
# --- Trainer ---
# Trainer setup remains the same
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=collator,
    callbacks=[StreamlitCallback()]
)
# --- Train ---
# Train call remains the same
try:
    # Start a fresh run (resume_from_checkpoint is None in the args above)
    trainer.train()
    progress.success("✅ Training complete.")
    st.success("Training finished!")
except Exception as e:
    st.error(f"Training failed: {e}")
    progress.error("❌ Training stopped due to error.")