dejanseo committed · Commit 4437e8c · verified · 1 Parent(s): 90751c3

Upload train_medium.py

Files changed (1): train_medium.py (+246 -0)
train_medium.py ADDED
# train_medium.py

import os
import math
import random
import torch
import pandas as pd
import numpy as np
import streamlit as st
import plotly.graph_objects as go
from transformers import (
    RobertaConfig, RobertaForMaskedLM, Trainer, TrainingArguments,
    PreTrainedTokenizerFast, DataCollatorForLanguageModeling, TrainerCallback
)
# Import Value from datasets alongside others
from datasets import load_dataset, Features, Sequence, Value

# --- Streamlit setup ---
st.set_page_config(layout="wide")

# --- Constants ---
TOKENIZER_DIR = "tokenizer"  # Ensure this matches the one used in preprocessing
DATA_PATH = "training_data.jsonl"  # Ensure this is the output from sentence_aware_processor.py
OUTPUT_DIR = "./checkpoints_medium"
VOCAB_SIZE = 32000
MAX_LEN = 512
BATCH_SIZE = 64
EPOCHS = 50
GRAD_ACC = 8
LEARNING_RATE = 1e-3
MLM_PROB = 0.15
SEED = 42

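# Effective batch size: with a per-device batch of 64 and 8 gradient-accumulation
# steps, each optimizer update sees 64 * 8 = 512 sequences per device. The
# relatively high LEARNING_RATE of 1e-3 assumes this large effective batch;
# treat it as a tunable starting point rather than a vetted setting.
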
# --- Seed ---
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(SEED)

# --- Tokenizer ---
if not os.path.exists(os.path.join(TOKENIZER_DIR, "tokenizer.json")):
    st.error(f"Tokenizer not found in {TOKENIZER_DIR}")
    st.stop()
try:
    tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_DIR)
    tokenizer.model_max_length = MAX_LEN
except Exception as e:
    st.error(f"Error loading tokenizer from {TOKENIZER_DIR}: {e}")
    st.stop()


# --- Dataset ---
# !!! MODIFIED: Updated Features definition to match the JSONL structure !!!
features = Features({
    'id': Value(dtype='int64'),                   # Added id field
    'input_ids': Sequence(Value(dtype='int32')),
    'source': Value(dtype='string')               # Added source field
})
# (Error handling remains)
try:
    # Load the dataset using the updated features
    dataset = load_dataset("json", data_files=DATA_PATH, features=features, split="train")
    # st.success(f"Loaded dataset from {DATA_PATH} with columns: {dataset.column_names}")
except Exception as e:
    st.error(f"Failed to load dataset from {DATA_PATH}: {e}")
    st.info(f"Ensure '{DATA_PATH}' exists and matches the features: {features}")
    st.stop()

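# Illustrative record shape (hypothetical values; the real ids and tokens come
# from sentence_aware_processor.py):
#   {"id": 0, "input_ids": [0, 4134, 982, 2], "source": "doc_000.txt"}
# Keys or dtypes that deviate from `features` make load_dataset raise, which
# the except branch above surfaces in the Streamlit UI.
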
# --- Add Attention Mask ---
# This function remains the same, as it only needs 'input_ids'
if 'attention_mask' not in dataset.column_names:
    def add_attention_mask(example):
        # The length is derived from the 'input_ids' field
        example["attention_mask"] = [1] * len(example["input_ids"])
        return example
    # os.cpu_count() can return None; fall back to a single process
    dataset = dataset.map(add_attention_mask, num_proc=max(1, (os.cpu_count() or 1) // 2))
    # st.info("Added 'attention_mask' column.")

# --- Collator ---
# Note: extra columns like 'id' and 'source' never reach the collator; the
# Trainer drops columns the model's forward() does not accept
# (remove_unused_columns=True by default).
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=MLM_PROB
)

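# For reference: DataCollatorForLanguageModeling applies BERT-style masking.
# Of the tokens selected at MLM_PROB (15% here), 80% become the mask token,
# 10% become a random vocabulary token, and 10% stay unchanged; all other
# label positions are set to -100 so they are excluded from the loss.
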
# --- Model ---
# Model definition remains the same
config = RobertaConfig(
    vocab_size=VOCAB_SIZE,
    hidden_size=512,
    num_hidden_layers=8,
    num_attention_heads=16,
    intermediate_size=2048,
    max_position_embeddings=MAX_LEN + 2,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.cls_token_id,
    eos_token_id=tokenizer.sep_token_id,
)
model = RobertaForMaskedLM(config=config)

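# Why MAX_LEN + 2: RoBERTa offsets position ids by padding_idx + 1, so a
# 512-token sequence needs 514 position embeddings; without the +2, full-length
# batches would index past the embedding table.
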
# --- UI State ---
# UI setup remains the same
log = {"step": [], "loss": [], "grad_norm": [], "perplexity": []}
progress = st.empty()
col1, col2 = st.columns(2)
with col1:
    chart1_placeholder = st.empty()
    chart2_placeholder = st.empty()
with col2:
    chart3_placeholder = st.empty()
    chart4_placeholder = st.empty()

# --- Plotting Functions (Unchanged) ---
def get_safe_range(values, pad_percent=0.1):
    values = pd.Series(values).dropna()
    if values.empty:
        return (0, 1)
    if len(values) == 1:
        return (values.iloc[0] * 0.9, values.iloc[0] * 1.1)
    numeric_values = pd.to_numeric(values, errors='coerce').dropna()
    if numeric_values.empty:
        return (0, 1)
    low, high = np.percentile(numeric_values, [2, 95])
    pad = abs(high - low) * pad_percent
    return max(0, low - pad), high + pad

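# Clamping the y-axis to the 2nd-95th percentile (plus padding) keeps one-off
# loss or gradient spikes from flattening the charts; only the visible range
# is trimmed, the raw values are still plotted.
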
def forecast_plot(df):
    if len(df) < 10:
        return go.Figure(layout_title_text="Loss Forecast (Need more data)")
    x = pd.to_numeric(df["step"], errors='coerce').dropna().values
    y = pd.to_numeric(df["loss"], errors='coerce').dropna().values
    if len(x) < 2 or len(y) < 2 or len(x) != len(y):
        return go.Figure(layout_title_text="Loss Forecast (Data error)")

    forecast_x = np.linspace(x[0], x[-1] * 1.5, 300)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=y, mode='lines', name="Actual Loss"))

    # Fit a degree-1 trend over the most recent 1%, 10%, and 50% of steps
    for percent, color in [(1, 'orange'), (10, 'green'), (50, 'red')]:
        n = max(5, int(len(x) * percent / 100))
        if len(x) >= n and n >= 2:
            sub_x, sub_y = x[-n:], y[-n:]
            try:
                valid_indices = ~np.isnan(sub_x) & ~np.isnan(sub_y)
                if np.sum(valid_indices) >= 2:
                    m, b = np.polyfit(sub_x[valid_indices], sub_y[valid_indices], 1)
                    y_fit = m * forecast_x + b
                    fig.add_trace(go.Scatter(x=forecast_x, y=y_fit, name=f"{percent}% Trend", line=dict(dash='dot', color=color)))
            except (np.linalg.LinAlgError, ValueError) as e:
                print(f"Warning: Could not fit trend for {percent}%: {e}")

    fig.update_layout(title="Loss Forecast", xaxis_title="Step", yaxis_title="Loss", legend_title_text='Trend % (Recent)')
    return fig

# --- Streamlit Callback (Unchanged) ---
class StreamlitCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_world_process_zero:
            if logs is not None and "loss" in logs:
                step = state.global_step
                loss = float(logs["loss"]) if isinstance(logs["loss"], (int, float)) else None
                grad = float(logs.get("grad_norm")) if isinstance(logs.get("grad_norm"), (int, float)) else None

                if loss is not None:
                    # Perplexity is exp(loss); cap the exponent to avoid math overflow
                    ppl = math.exp(min(loss, 700))
                    log["step"].append(step)
                    log["loss"].append(loss)
                    log["grad_norm"].append(grad)
                    log["perplexity"].append(ppl)

                    df = pd.DataFrame(log).dropna(subset=['step', 'loss'])
                    if not df.empty:
                        try:
                            r1 = get_safe_range(df["loss"])
                            r2 = get_safe_range(df["grad_norm"])
                            r3 = get_safe_range(df["perplexity"])

                            fig1 = go.Figure().add_trace(go.Scatter(x=df["step"], y=df["loss"], mode='lines'))
                            grad_norm_data = df["grad_norm"].dropna()
                            if not grad_norm_data.empty:
                                fig2 = go.Figure().add_trace(go.Scatter(x=df.loc[grad_norm_data.index, "step"], y=grad_norm_data, mode='lines'))
                            else:
                                fig2 = go.Figure()
                            fig3 = go.Figure().add_trace(go.Scatter(x=df["step"], y=df["perplexity"], mode='lines'))

                            fig1.update_layout(title="Loss", yaxis_range=r1, xaxis_title="Step", yaxis_title="Loss")
                            fig2.update_layout(title="Gradient Norm", yaxis_range=r2, xaxis_title="Step", yaxis_title="Grad Norm")
                            fig3.update_layout(title="Perplexity", yaxis_range=r3, xaxis_title="Step", yaxis_title="Perplexity")

                            fig4 = forecast_plot(df)

                            chart1_placeholder.plotly_chart(fig1, use_container_width=True, key=f"loss_chart_{step}")
                            chart2_placeholder.plotly_chart(fig2, use_container_width=True, key=f"grad_norm_chart_{step}")
                            chart3_placeholder.plotly_chart(fig3, use_container_width=True, key=f"perplexity_chart_{step}")
                            chart4_placeholder.plotly_chart(fig4, use_container_width=True, key=f"forecast_chart_{step}")
                        except Exception as e:
                            print(f"Error updating Streamlit charts at step {step}: {e}")


# --- Training args ---
# Training args remain the same
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACC,
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type='linear',
    warmup_ratio=0.1,
    weight_decay=0.01,
    max_grad_norm=1.0,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=100,
    logging_strategy="steps",
    logging_steps=10,
    dataloader_num_workers=4,
    # Check availability first so CPU-only runs don't query the CUDA API
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
    seed=SEED,
    report_to=["none"],
    # Note: per the TrainingArguments docs this field is not read by
    # Trainer.train() itself; resuming is controlled by the argument passed
    # to train(). Kept here to document that this is a fresh run.
    resume_from_checkpoint=False,
)

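# Rough bookkeeping (assuming a single GPU): steps per epoch is about
# ceil(len(dataset) / (BATCH_SIZE * GRAD_ACC)); with save_steps=1000 and
# save_total_limit=100, only the 100 most recent checkpoints are kept.
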
# --- Trainer ---
# Trainer setup remains the same
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=collator,
    callbacks=[StreamlitCallback()]
)

# --- Train ---
try:
    # Fresh run by default; to resume instead, pass the argument explicitly,
    # e.g. trainer.train(resume_from_checkpoint=True) to pick up the latest
    # checkpoint in OUTPUT_DIR.
    trainer.train()
    progress.success("✅ Training complete.")
    st.success("Training finished!")
except Exception as e:
    st.error(f"Training failed: {e}")
    progress.error("❌ Training stopped due to error.")
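
# Usage sketch (assumption: the script is launched through Streamlit's CLI so
# the dashboard renders; a plain `python` invocation would run without the
# live charts):
#   streamlit run train_medium.py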