Upload train_medium.py
train_medium.py +246 -0
train_medium.py
ADDED
@@ -0,0 +1,246 @@
# train_fixed_clean_keys_v2.py

import os
import math
import random
import torch
import pandas as pd
import numpy as np
import streamlit as st
import plotly.graph_objects as go
from transformers import (
    RobertaConfig, RobertaForMaskedLM, Trainer, TrainingArguments,
    PreTrainedTokenizerFast, DataCollatorForLanguageModeling, TrainerCallback
)
# Import Value from datasets alongside others
from datasets import load_dataset, Features, Sequence, Value

# --- Streamlit setup ---
st.set_page_config(layout="wide")

# --- Constants ---
TOKENIZER_DIR = "tokenizer"  # Ensure this matches the one used in preprocessing
DATA_PATH = "training_data.jsonl"  # Ensure this is the output from sentence_aware_processor.py
OUTPUT_DIR = "./checkpoints_medium"
VOCAB_SIZE = 32000
MAX_LEN = 512
BATCH_SIZE = 64
EPOCHS = 50
GRAD_ACC = 8
LEARNING_RATE = 1e-3
MLM_PROB = 0.15
SEED = 42
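# Note: with gradient accumulation, the effective batch size per optimizer step
# is BATCH_SIZE * GRAD_ACC = 64 * 8 = 512 sequences (per device).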

# --- Seed ---
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(SEED)

# --- Tokenizer ---
if not os.path.exists(os.path.join(TOKENIZER_DIR, "tokenizer.json")):
    st.error(f"Tokenizer not found in {TOKENIZER_DIR}")
    st.stop()
try:
    tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_DIR)
    tokenizer.model_max_length = MAX_LEN
except Exception as e:
    st.error(f"Error loading tokenizer from {TOKENIZER_DIR}: {e}")
    st.stop()


# --- Dataset ---
# !!! MODIFIED: Updated Features definition to match the JSONL structure !!!
features = Features({
    'id': Value(dtype='int64'),                   # Added id field
    'input_ids': Sequence(Value(dtype='int32')),
    'source': Value(dtype='string')               # Added source field
})
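# For reference, each line of training_data.jsonl is expected to look roughly like
# (illustrative values only; real token ids come from the tokenizer in TOKENIZER_DIR):
#   {"id": 0, "input_ids": [0, 1543, 2044, 87, 2], "source": "example_corpus"}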
# (Error handling remains)
try:
    # Load the dataset using the updated features
    dataset = load_dataset("json", data_files=DATA_PATH, features=features, split="train")
    #st.success(f"Loaded dataset from {DATA_PATH} with columns: {dataset.column_names}")
except Exception as e:
    st.error(f"Failed to load dataset from {DATA_PATH}: {e}")
    st.info(f"Ensure '{DATA_PATH}' exists and matches the features: {features}")
    st.stop()

# --- Add Attention Mask ---
# This function remains the same, as it only needs 'input_ids'
if 'attention_mask' not in dataset.column_names:
    def add_attention_mask(example):
        # The length is derived from the 'input_ids' field
        example["attention_mask"] = [1] * len(example["input_ids"])
        return example
    dataset = dataset.map(add_attention_mask, num_proc=max(1, os.cpu_count() // 2))
    #st.info("Added 'attention_mask' column.")

# --- Collator ---
# DataCollatorForLanguageModeling will automatically ignore extra columns like 'id' and 'source'
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=MLM_PROB
)

# --- Model ---
# Model definition remains the same
config = RobertaConfig(
    vocab_size=VOCAB_SIZE,
    hidden_size=512,
    num_hidden_layers=8,
    num_attention_heads=16,
    intermediate_size=2048,
    max_position_embeddings=MAX_LEN + 2,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.cls_token_id,
    eos_token_id=tokenizer.sep_token_id,
)
model = RobertaForMaskedLM(config=config)
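# Rough size estimate for this config: on the order of 40M parameters
# (the 32000 x 512 embedding matrix alone is ~16M; each of the 8 layers adds ~3M).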

# --- UI State ---
# UI setup remains the same
log = {"step": [], "loss": [], "grad_norm": [], "perplexity": []}
progress = st.empty()
col1, col2 = st.columns(2)
with col1:
    chart1_placeholder = st.empty()
    chart2_placeholder = st.empty()
with col2:
    chart3_placeholder = st.empty()
    chart4_placeholder = st.empty()

# --- Plotting Functions (Unchanged) ---
def get_safe_range(values, pad_percent=0.1):
    values = pd.Series(values).dropna()
    if values.empty: return (0, 1)
    if len(values) == 1: return (values.iloc[0] * 0.9, values.iloc[0] * 1.1)
    numeric_values = pd.to_numeric(values, errors='coerce').dropna()
    if numeric_values.empty: return (0, 1)
    low, high = np.percentile(numeric_values, [2, 95])
    pad = abs(high - low) * pad_percent
    return max(0, low - pad), high + pad

def forecast_plot(df):
    if len(df) < 10: return go.Figure(layout_title_text="Loss Forecast (Need more data)")
    x = pd.to_numeric(df["step"], errors='coerce').dropna().values
    y = pd.to_numeric(df["loss"], errors='coerce').dropna().values
    if len(x) < 2 or len(y) < 2 or len(x) != len(y):
        return go.Figure(layout_title_text="Loss Forecast (Data error)")

    forecast_x = np.linspace(x[0], x[-1] * 1.5, 300)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=y, mode='lines', name="Actual Loss"))

    for percent, color in [(1, 'orange'), (10, 'green'), (50, 'red')]:
        n = max(5, int(len(x) * percent / 100))
        if len(x) >= n and n >= 2:
            sub_x, sub_y = x[-n:], y[-n:]
            try:
                valid_indices = ~np.isnan(sub_x) & ~np.isnan(sub_y)
                if np.sum(valid_indices) >= 2:
                    m, b = np.polyfit(sub_x[valid_indices], sub_y[valid_indices], 1)
                    y_fit = m * forecast_x + b
                    fig.add_trace(go.Scatter(x=forecast_x, y=y_fit, name=f"{percent}% Trend", line=dict(dash='dot', color=color)))
            except (np.linalg.LinAlgError, ValueError) as e:
                print(f"Warning: Could not fit trend for {percent}%: {e}")

    fig.update_layout(title="Loss Forecast", xaxis_title="Step", yaxis_title="Loss", legend_title_text='Trend % (Recent)')
    return fig

# --- Streamlit Callback (Unchanged) ---
class StreamlitCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_world_process_zero:
            if logs is not None and "loss" in logs:
                step = state.global_step
                loss = float(logs["loss"]) if isinstance(logs["loss"], (int, float)) else None
                grad = float(logs.get("grad_norm")) if isinstance(logs.get("grad_norm"), (int, float)) else None

                if loss is not None:
                    ppl = math.exp(min(loss, 700))
                    log["step"].append(step)
                    log["loss"].append(loss)
                    log["grad_norm"].append(grad)
                    log["perplexity"].append(ppl)

                    df = pd.DataFrame(log).dropna(subset=['step', 'loss'])
                    if not df.empty:
                        try:
                            r1 = get_safe_range(df["loss"])
                            r2 = get_safe_range(df["grad_norm"])
                            r3 = get_safe_range(df["perplexity"])

                            fig1 = go.Figure().add_trace(go.Scatter(x=df["step"], y=df["loss"], mode='lines'))
                            grad_norm_data = df["grad_norm"].dropna()
                            if not grad_norm_data.empty:
                                fig2 = go.Figure().add_trace(go.Scatter(x=df.loc[grad_norm_data.index, "step"], y=grad_norm_data, mode='lines'))
                            else:
                                fig2 = go.Figure()
                            fig3 = go.Figure().add_trace(go.Scatter(x=df["step"], y=df["perplexity"], mode='lines'))

                            fig1.update_layout(title="Loss", yaxis_range=r1, xaxis_title="Step", yaxis_title="Loss")
                            fig2.update_layout(title="Gradient Norm", yaxis_range=r2, xaxis_title="Step", yaxis_title="Grad Norm")
                            fig3.update_layout(title="Perplexity", yaxis_range=r3, xaxis_title="Step", yaxis_title="Perplexity")

                            fig4 = forecast_plot(df)

                            chart1_placeholder.plotly_chart(fig1, use_container_width=True, key=f"loss_chart_{step}")
                            chart2_placeholder.plotly_chart(fig2, use_container_width=True, key=f"grad_norm_chart_{step}")
                            chart3_placeholder.plotly_chart(fig3, use_container_width=True, key=f"perplexity_chart_{step}")
                            chart4_placeholder.plotly_chart(fig4, use_container_width=True, key=f"forecast_chart_{step}")
                        except Exception as e:
                            print(f"Error updating Streamlit charts at step {step}: {e}")


# --- Training args ---
# Training args remain the same
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACC,
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type='linear',
    warmup_ratio=0.1,
    weight_decay=0.01,
    max_grad_norm=1.0,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=100,
    logging_strategy="steps",
    logging_steps=10,
    dataloader_num_workers=4,
    bf16=torch.cuda.is_bf16_supported(),
    fp16=not torch.cuda.is_bf16_supported() and torch.cuda.is_available(),
    seed=SEED,
    report_to=["none"],
    # !! Remember to handle checkpoints appropriately for a fresh run !!
    resume_from_checkpoint=False,  # Explicitly set to False for clean run
)
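# With lr_scheduler_type='linear' and warmup_ratio=0.1, the learning rate ramps up
# to LEARNING_RATE over the first 10% of optimizer steps, then decays linearly to 0.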

# --- Trainer ---
# Trainer setup remains the same
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=collator,
    callbacks=[StreamlitCallback()]
)

# --- Train ---
# Train call remains the same
try:
    # Start training (explicitly not resuming here due to args setting)
    trainer.train()  # No need to pass resume_from_checkpoint if set in args
    progress.success("✅ Training complete.")
    st.success("Training finished!")
except Exception as e:
    st.error(f"Training failed: {e}")
    progress.error("❌ Training stopped due to error.")
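Usage note: this file is a Streamlit app rather than a plain training script. Assuming Streamlit is installed, launch it with "streamlit run train_medium.py" (not "python train_medium.py") so the loss, gradient-norm, perplexity, and forecast charts can render in the browser while the Trainer runs.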