# train_fixed_clean_keys_v2.py
import os
import math
import random

import torch
import pandas as pd
import numpy as np
import streamlit as st
import plotly.graph_objects as go
from transformers import (
    RobertaConfig,
    RobertaForMaskedLM,
    Trainer,
    TrainingArguments,
    PreTrainedTokenizerFast,
    DataCollatorForLanguageModeling,
    TrainerCallback,
)
# Import Value from datasets alongside the others
from datasets import load_dataset, Features, Sequence, Value

# --- Streamlit setup ---
st.set_page_config(layout="wide")

# --- Constants ---
TOKENIZER_DIR = "tokenizer"        # Ensure this matches the one used in preprocessing
DATA_PATH = "training_data.jsonl"  # Ensure this is the output from sentence_aware_processor.py
OUTPUT_DIR = "./checkpoints"
VOCAB_SIZE = 32000
MAX_LEN = 512
BATCH_SIZE = 64
EPOCHS = 50
GRAD_ACC = 8
LEARNING_RATE = 1e-3
MLM_PROB = 0.15
SEED = 42

# --- Seed ---
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(SEED)

# --- Tokenizer ---
if not os.path.exists(os.path.join(TOKENIZER_DIR, "tokenizer.json")):
    st.error(f"Tokenizer not found in {TOKENIZER_DIR}")
    st.stop()
try:
    tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_DIR)
    tokenizer.model_max_length = MAX_LEN
except Exception as e:
    st.error(f"Error loading tokenizer from {TOKENIZER_DIR}: {e}")
    st.stop()

# --- Dataset ---
# !!! MODIFIED: Updated Features definition to match the JSONL structure !!!
features = Features({
    "id": Value(dtype="int64"),                   # Added id field
    "input_ids": Sequence(Value(dtype="int32")),
    "source": Value(dtype="string"),              # Added source field
})
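# Illustrative record shape for training_data.jsonl (the values below are made
# up for illustration; the real file is produced by sentence_aware_processor.py):
#   {"id": 0, "input_ids": [5, 1042, 2087, 6], "source": "corpus_a.txt"}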
# (Error handling remains)
try:
    # Load the dataset using the updated features
    dataset = load_dataset("json", data_files=DATA_PATH, features=features, split="train")
    # st.success(f"Loaded dataset from {DATA_PATH} with columns: {dataset.column_names}")
except Exception as e:
    st.error(f"Failed to load dataset from {DATA_PATH}: {e}")
    st.info(f"Ensure '{DATA_PATH}' exists and matches the features: {features}")
    st.stop()

# --- Add Attention Mask ---
# This function remains the same, as it only needs 'input_ids'
if "attention_mask" not in dataset.column_names:
    def add_attention_mask(example):
        # The mask length is derived from the 'input_ids' field
        example["attention_mask"] = [1] * len(example["input_ids"])
        return example

    dataset = dataset.map(add_attention_mask, num_proc=max(1, os.cpu_count() // 2))
    # st.info("Added 'attention_mask' column.")

# --- Collator ---
# DataCollatorForLanguageModeling will automatically ignore extra columns like 'id' and 'source'
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=MLM_PROB,
)

# --- Model ---
# Model definition remains the same
config = RobertaConfig(
    vocab_size=VOCAB_SIZE,
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=8,
    intermediate_size=1024,
    max_position_embeddings=MAX_LEN + 2,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.cls_token_id,
    eos_token_id=tokenizer.sep_token_id,
)
model = RobertaForMaskedLM(config=config)

# --- UI State ---
# UI setup remains the same
log = {"step": [], "loss": [], "grad_norm": [], "perplexity": []}
progress = st.empty()
col1, col2 = st.columns(2)
with col1:
    chart1_placeholder = st.empty()
    chart2_placeholder = st.empty()
with col2:
    chart3_placeholder = st.empty()
    chart4_placeholder = st.empty()

# --- Plotting Functions (Unchanged) ---
def get_safe_range(values, pad_percent=0.1):
    values = pd.Series(values).dropna()
    if values.empty:
        return (0, 1)
    if len(values) == 1:
        return (values.iloc[0] * 0.9, values.iloc[0] * 1.1)
    numeric_values = pd.to_numeric(values, errors="coerce").dropna()
    if numeric_values.empty:
        return (0, 1)
    low, high = np.percentile(numeric_values, [2, 95])
    pad = abs(high - low) * pad_percent
    return max(0, low - pad), high + pad


def forecast_plot(df):
    if len(df) < 10:
        return go.Figure(layout_title_text="Loss Forecast (Need more data)")
    x = pd.to_numeric(df["step"], errors="coerce").dropna().values
    y = pd.to_numeric(df["loss"], errors="coerce").dropna().values
    if len(x) < 2 or len(y) < 2 or len(x) != len(y):
        return go.Figure(layout_title_text="Loss Forecast (Data error)")
    forecast_x = np.linspace(x[0], x[-1] * 1.5, 300)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=y, mode="lines", name="Actual Loss"))
    # Fit a linear trend over the most recent 1%, 10%, and 50% of points
    for percent, color in [(1, "orange"), (10, "green"), (50, "red")]:
        n = max(5, int(len(x) * percent / 100))
        if len(x) >= n and n >= 2:
            sub_x, sub_y = x[-n:], y[-n:]
            try:
                valid_indices = ~np.isnan(sub_x) & ~np.isnan(sub_y)
                if np.sum(valid_indices) >= 2:
                    m, b = np.polyfit(sub_x[valid_indices], sub_y[valid_indices], 1)
                    y_fit = m * forecast_x + b
                    fig.add_trace(go.Scatter(
                        x=forecast_x,
                        y=y_fit,
                        name=f"{percent}% Trend",
                        line=dict(dash="dot", color=color),
                    ))
            except (np.linalg.LinAlgError, ValueError) as e:
                print(f"Warning: Could not fit trend for {percent}%: {e}")
    fig.update_layout(
        title="Loss Forecast",
        xaxis_title="Step",
        yaxis_title="Loss",
        legend_title_text="Trend % (Recent)",
    )
    return fig


# --- Streamlit Callback (Unchanged) ---
class StreamlitCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero:
            return
        if logs is None or "loss" not in logs:
            return
        step = state.global_step
        loss = float(logs["loss"]) if isinstance(logs["loss"], (int, float)) else None
        grad = float(logs.get("grad_norm")) if isinstance(logs.get("grad_norm"), (int, float)) else None
        if loss is None:
            return
        ppl = math.exp(min(loss, 700))  # clamp so exp() cannot overflow
        log["step"].append(step)
        log["loss"].append(loss)
        log["grad_norm"].append(grad)
        log["perplexity"].append(ppl)
        df = pd.DataFrame(log).dropna(subset=["step", "loss"])
        if df.empty:
            return
        try:
            r1 = get_safe_range(df["loss"])
            r2 = get_safe_range(df["grad_norm"])
            r3 = get_safe_range(df["perplexity"])
            fig1 = go.Figure().add_trace(go.Scatter(x=df["step"], y=df["loss"], mode="lines"))
            grad_norm_data = df["grad_norm"].dropna()
            if not grad_norm_data.empty:
                fig2 = go.Figure().add_trace(go.Scatter(
                    x=df.loc[grad_norm_data.index, "step"],
                    y=grad_norm_data,
                    mode="lines",
                ))
            else:
                fig2 = go.Figure()
            fig3 = go.Figure().add_trace(go.Scatter(x=df["step"], y=df["perplexity"], mode="lines"))
            fig1.update_layout(title="Loss", yaxis_range=r1, xaxis_title="Step", yaxis_title="Loss")
            fig2.update_layout(title="Gradient Norm", yaxis_range=r2, xaxis_title="Step", yaxis_title="Grad Norm")
            fig3.update_layout(title="Perplexity", yaxis_range=r3, xaxis_title="Step", yaxis_title="Perplexity")
            fig4 = forecast_plot(df)
            chart1_placeholder.plotly_chart(fig1, use_container_width=True, key=f"loss_chart_{step}")
            chart2_placeholder.plotly_chart(fig2, use_container_width=True, key=f"grad_norm_chart_{step}")
            chart3_placeholder.plotly_chart(fig3, use_container_width=True, key=f"perplexity_chart_{step}")
            chart4_placeholder.plotly_chart(fig4, use_container_width=True, key=f"forecast_chart_{step}")
        except Exception as e:
            print(f"Error updating Streamlit charts at step {step}: {e}")

# --- Training args ---
# Training args remain the same
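# Note on batch sizing: gradients are accumulated across GRAD_ACC forward
# passes before each optimizer update, so each update sees an effective batch
# of BATCH_SIZE * GRAD_ACC = 64 * 8 = 512 sequences per device.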
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACC,
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    weight_decay=0.01,
    max_grad_norm=1.0,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=10,
    logging_strategy="steps",
    logging_steps=10,
    dataloader_num_workers=4,
    # Guard the bf16 probe behind is_available() so CPU-only runs stay safe
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
    seed=SEED,
    report_to=["none"],
    # !! Remember to handle checkpoints appropriately for a fresh run !!
    resume_from_checkpoint=False,  # Explicitly set to False for a clean run
)

# --- Trainer ---
# Trainer setup remains the same
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=collator,
    callbacks=[StreamlitCallback()],
)

# --- Train ---
# Train call remains the same
try:
    # No checkpoint is passed to train(), so this is an explicitly fresh run
    trainer.train()
    progress.success("✅ Training complete.")
    st.success("Training finished!")
except Exception as e:
    st.error(f"Training failed: {e}")
    progress.error("❌ Training stopped due to error.")
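# --- Usage ---
# The dashboard is rendered by Streamlit, so launch the script with
#   streamlit run train_fixed_clean_keys_v2.py
# rather than plain `python`; outside a Streamlit session the st.* placeholders
# have no page to draw on, and the live loss/grad-norm/perplexity charts will
# not appear.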