"""Streamlit app: pre-train a small RoBERTa masked language model from scratch on a
pre-tokenized JSONL dataset, with live loss/gradient-norm/perplexity charts."""

import os
import math
import random

import torch
import pandas as pd
import numpy as np
import streamlit as st
import plotly.graph_objects as go

from transformers import (
    RobertaConfig, RobertaForMaskedLM, Trainer, TrainingArguments,
    PreTrainedTokenizerFast, DataCollatorForLanguageModeling, TrainerCallback
)
from datasets import load_dataset, Features, Sequence, Value

st.set_page_config(layout="wide")

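# --- Configuration: file paths and training hyperparameters ---
# Effective batch per optimizer step: BATCH_SIZE * GRAD_ACC = 512 sequences per device.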
TOKENIZER_DIR = "tokenizer"
DATA_PATH = "training_data.jsonl"
OUTPUT_DIR = "./checkpoints"
VOCAB_SIZE = 32000
MAX_LEN = 512
BATCH_SIZE = 64
EPOCHS = 50
GRAD_ACC = 8
LEARNING_RATE = 1e-3
MLM_PROB = 0.15
SEED = 42

def set_seed(seed):
    """Seed Python, NumPy, and PyTorch RNGs (including CUDA) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(SEED)

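# --- Tokenizer: load the pre-trained fast tokenizer from disk ---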
if not os.path.exists(os.path.join(TOKENIZER_DIR, "tokenizer.json")):
    st.error(f"Tokenizer not found in {TOKENIZER_DIR}")
    st.stop()
try:
    tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_DIR)
    tokenizer.model_max_length = MAX_LEN
except Exception as e:
    st.error(f"Error loading tokenizer from {TOKENIZER_DIR}: {e}")
    st.stop()

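# --- Dataset: pre-tokenized JSONL with one record per sequence ---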
features = Features({
    'id': Value(dtype='int64'),
    'input_ids': Sequence(Value(dtype='int32')),
    'source': Value(dtype='string')
})

try:
    dataset = load_dataset("json", data_files=DATA_PATH, features=features, split="train")
except Exception as e:
    st.error(f"Failed to load dataset from {DATA_PATH}: {e}")
    st.info(f"Ensure '{DATA_PATH}' exists and matches the features: {features}")
    st.stop()

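# The records are already tokenized; add a full attention mask if it is missing.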
if 'attention_mask' not in dataset.column_names:
    def add_attention_mask(example):
        example["attention_mask"] = [1] * len(example["input_ids"])
        return example
    dataset = dataset.map(add_attention_mask, num_proc=max(1, os.cpu_count() // 2))

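# --- Data collator: applies dynamic MLM masking to each batch ---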
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=MLM_PROB
)

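# --- Model: a small RoBERTa (4 layers, 256 hidden units) trained from scratch ---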
config = RobertaConfig(
    vocab_size=VOCAB_SIZE,
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=8,
    intermediate_size=1024,
    max_position_embeddings=MAX_LEN + 2,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.cls_token_id,
    eos_token_id=tokenizer.sep_token_id,
)
model = RobertaForMaskedLM(config=config)

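# --- Streamlit dashboard: placeholders for the four live metric charts ---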
log = {"step": [], "loss": [], "grad_norm": [], "perplexity": []}
progress = st.empty()
col1, col2 = st.columns(2)
with col1:
    chart1_placeholder = st.empty()
    chart2_placeholder = st.empty()
with col2:
    chart3_placeholder = st.empty()
    chart4_placeholder = st.empty()

def get_safe_range(values, pad_percent=0.1):
    """Return a padded (low, high) y-axis range from the 2nd-95th percentiles, clamped at 0."""
    values = pd.Series(values).dropna()
    if values.empty:
        return (0, 1)
    if len(values) == 1:
        return (values.iloc[0] * 0.9, values.iloc[0] * 1.1)
    numeric_values = pd.to_numeric(values, errors='coerce').dropna()
    if numeric_values.empty:
        return (0, 1)
    low, high = np.percentile(numeric_values, [2, 95])
    pad = abs(high - low) * pad_percent
    return max(0, low - pad), high + pad

def forecast_plot(df):
    """Plot the observed loss plus linear trends fit to the most recent 1%, 10%, and 50% of steps."""
    if len(df) < 10:
        return go.Figure(layout_title_text="Loss Forecast (Need more data)")
    x = pd.to_numeric(df["step"], errors='coerce').dropna().values
    y = pd.to_numeric(df["loss"], errors='coerce').dropna().values
    if len(x) < 2 or len(y) < 2 or len(x) != len(y):
        return go.Figure(layout_title_text="Loss Forecast (Data error)")

    # Extrapolate 50% beyond the latest logged step.
    forecast_x = np.linspace(x[0], x[-1] * 1.5, 300)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=y, mode='lines', name="Actual Loss"))

    for percent, color in [(1, 'orange'), (10, 'green'), (50, 'red')]:
        n = max(5, int(len(x) * percent / 100))
        if len(x) >= n and n >= 2:
            sub_x, sub_y = x[-n:], y[-n:]
            try:
                valid_indices = ~np.isnan(sub_x) & ~np.isnan(sub_y)
                if np.sum(valid_indices) >= 2:
                    m, b = np.polyfit(sub_x[valid_indices], sub_y[valid_indices], 1)
                    y_fit = m * forecast_x + b
                    fig.add_trace(go.Scatter(x=forecast_x, y=y_fit, name=f"{percent}% Trend",
                                             line=dict(dash='dot', color=color)))
            except (np.linalg.LinAlgError, ValueError) as e:
                print(f"Warning: Could not fit trend for {percent}%: {e}")

    fig.update_layout(title="Loss Forecast", xaxis_title="Step", yaxis_title="Loss",
                      legend_title_text='Trend % (Recent)')
    return fig

class StreamlitCallback(TrainerCallback):
    """On each Trainer log event (main process only), record loss, gradient norm, and
    perplexity, then redraw the live Streamlit charts."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_world_process_zero:
            if logs is not None and "loss" in logs:
                step = state.global_step
                loss = float(logs["loss"]) if isinstance(logs["loss"], (int, float)) else None
                grad = float(logs.get("grad_norm")) if isinstance(logs.get("grad_norm"), (int, float)) else None

                if loss is not None:
                    ppl = math.exp(min(loss, 700))  # cap the exponent to avoid overflow
                    log["step"].append(step)
                    log["loss"].append(loss)
                    log["grad_norm"].append(grad)
                    log["perplexity"].append(ppl)

                df = pd.DataFrame(log).dropna(subset=['step', 'loss'])
                if not df.empty:
                    try:
                        r1 = get_safe_range(df["loss"])
                        r2 = get_safe_range(df["grad_norm"])
                        r3 = get_safe_range(df["perplexity"])

                        fig1 = go.Figure().add_trace(go.Scatter(x=df["step"], y=df["loss"], mode='lines'))
                        grad_norm_data = df["grad_norm"].dropna()
                        if not grad_norm_data.empty:
                            fig2 = go.Figure().add_trace(go.Scatter(x=df.loc[grad_norm_data.index, "step"],
                                                                    y=grad_norm_data, mode='lines'))
                        else:
                            fig2 = go.Figure()
                        fig3 = go.Figure().add_trace(go.Scatter(x=df["step"], y=df["perplexity"], mode='lines'))

                        fig1.update_layout(title="Loss", yaxis_range=r1, xaxis_title="Step", yaxis_title="Loss")
                        fig2.update_layout(title="Gradient Norm", yaxis_range=r2, xaxis_title="Step", yaxis_title="Grad Norm")
                        fig3.update_layout(title="Perplexity", yaxis_range=r3, xaxis_title="Step", yaxis_title="Perplexity")

                        fig4 = forecast_plot(df)

                        chart1_placeholder.plotly_chart(fig1, use_container_width=True, key=f"loss_chart_{step}")
                        chart2_placeholder.plotly_chart(fig2, use_container_width=True, key=f"grad_norm_chart_{step}")
                        chart3_placeholder.plotly_chart(fig3, use_container_width=True, key=f"perplexity_chart_{step}")
                        chart4_placeholder.plotly_chart(fig4, use_container_width=True, key=f"forecast_chart_{step}")
                    except Exception as e:
                        print(f"Error updating Streamlit charts at step {step}: {e}")

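# --- Training arguments ---
# Use bf16 when the GPU supports it, otherwise fp16 on CUDA; neither on CPU.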
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACC,
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type='linear',
    warmup_ratio=0.1,
    weight_decay=0.01,
    max_grad_norm=1.0,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=10,
    logging_strategy="steps",
    logging_steps=10,
    dataloader_num_workers=4,
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
    seed=SEED,
    report_to=["none"],
    resume_from_checkpoint=False,  # always start from scratch
)

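# --- Trainer: ties together the model, dataset, collator, and live-chart callback ---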
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=collator,
    callbacks=[StreamlitCallback()]
)

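# --- Run training and report the final status in the UI ---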
try:
    trainer.train()
    progress.success("✅ Training complete.")
    st.success("Training finished!")
except Exception as e:
    st.error(f"Training failed: {e}")
    progress.error("❌ Training stopped due to error.")