Update train.py

train.py (CHANGED)
@@ -1,9 +1,8 @@
 #!/usr/bin/env python
 # train_cuad_lora_improved.py
 """
-CUAD fine-tune with LoRA on …
-Improved version with better error handling and …
-Expected wall-clock on Nvidia L4: ~25-30 min.
+CUAD fine-tune with LoRA on L40S GPU in HuggingFace Spaces.
+Improved version with better error handling and chunked processing.
 """

 import os, json, random, gc, time
@@ -20,14 +19,15 @@ from peft import LoraConfig, get_peft_model, TaskType
 import evaluate
 from huggingface_hub import login

-disable_caching()
+disable_caching()

 # ─────────────────────────────────────────────────────────────── helpers ──

-MAX_LEN = 384
+MAX_LEN = 384
 DOC_STRIDE = 128
 SEED = 42
 CHECKPOINT_DIR = "./cuad_lora_checkpoints"
+CHUNK_SIZE = 100  # Process in smaller chunks to avoid timeouts

 def set_seed(seed):
     random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
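A note on the constants: MAX_LEN and DOC_STRIDE drive the tokenizer's sliding window over long contracts, while the new CHUNK_SIZE only sets checkpoint granularity. As a rough sanity check, a sketch with toy numbers (ignoring the question tokens, which also consume window budget) of how many overlapping windows a long context produces:

    # Each extra window advances by MAX_LEN - DOC_STRIDE fresh tokens,
    # mirroring the tokenizer's return_overflowing_tokens behaviour.
    MAX_LEN, DOC_STRIDE = 384, 128

    def num_windows(n_tokens: int) -> int:
        if n_tokens <= MAX_LEN:
            return 1
        step = MAX_LEN - DOC_STRIDE                   # 256 new tokens per extra window
        return 1 + -(-(n_tokens - MAX_LEN) // step)   # ceiling division

    print(num_windows(300))    # 1 window: fits outright
    print(num_windows(2000))   # 8 windows for a long contract excerpt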
@@ -43,9 +43,42 @@ def load_checkpoint(checkpoint_path):
     """Load preprocessing checkpoint from disk"""
     if os.path.exists(checkpoint_path):
         print(f"📂 Loading checkpoint: {checkpoint_path}")
-        return torch.load(checkpoint_path)
+        return torch.load(checkpoint_path, map_location='cpu')
     return None

+def save_partial_features(features_dict, chunk_idx, dataset_name):
+    """Save partial features for a chunk"""
+    partial_path = f"{CHECKPOINT_DIR}/{dataset_name}_chunk_{chunk_idx:04d}.pt"
+    save_checkpoint(features_dict, partial_path)
+    return partial_path
+
+def load_and_combine_chunks(dataset_name):
+    """Load all chunk files and combine them"""
+    chunk_files = []
+    if os.path.exists(CHECKPOINT_DIR):
+        for f in os.listdir(CHECKPOINT_DIR):
+            if f.startswith(f"{dataset_name}_chunk_") and f.endswith('.pt'):
+                chunk_files.append(os.path.join(CHECKPOINT_DIR, f))
+
+    if not chunk_files:
+        return None
+
+    chunk_files.sort()
+    print(f"📂 Found {len(chunk_files)} chunks for {dataset_name}")
+
+    # Combine all chunks
+    combined = None
+    for chunk_file in chunk_files:
+        chunk_data = torch.load(chunk_file, map_location='cpu')
+        if combined is None:
+            combined = chunk_data
+        else:
+            for key in chunk_data:
+                combined[key].extend(chunk_data[key])
+
+    print(f"✅ Combined {len(combined['input_ids'])} features from chunks")
+    return combined
+
 def balance_has_answer(dataset, ratio=2.0):
     """Keep all has-answer rows, down-sample no-answer rows to `ratio`."""
     has, no = [], []
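The chunk files written by save_partial_features are plain dict-of-lists payloads, so resuming is just list concatenation; the zero-padded {chunk_idx:04d} filenames are what make the lexicographic sort() line up with chunk order. A self-contained sketch of the round trip (toy payloads, a temp dir standing in for CHECKPOINT_DIR):

    import os, tempfile, torch

    ckpt_dir = tempfile.mkdtemp()
    for idx in range(2):
        payload = {"input_ids": [[idx] * 4], "example_id": [f"ex-{idx}"]}
        torch.save(payload, os.path.join(ckpt_dir, f"train_chunk_{idx:04d}.pt"))

    # Combine in sorted filename order, extending each list key,
    # exactly as load_and_combine_chunks does above.
    combined = None
    for name in sorted(os.listdir(ckpt_dir)):
        chunk = torch.load(os.path.join(ckpt_dir, name), map_location="cpu")
        if combined is None:
            combined = chunk
        else:
            for key in chunk:
                combined[key].extend(chunk[key])

    print(len(combined["input_ids"]))  # 2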
@@ -105,112 +138,163 @@ def compute_metrics(eval_pred):

 # ───────────────────────────────────────────────────────────────── main ──

-def …
-    """…
-    …
-    )
-    …
-    offset_mapping = tokenized["offset_mapping"]
-    # …
-    …
-    # Find CLS token position (always 0 for RoBERTa)
-    cls_index = 0
-    …
-    example_ids.append(examples["id"][sample_idx])
-    …
-    # No answer case
-    if not answers["text"] or not answers["text"][0]:
-        start_positions.append(cls_index)
-        end_positions.append(cls_index)
-        continue
-    …
-    # Get answer span
-    answer_start = answers["answer_start"][0]
-    answer_text = answers["text"][0]
-    answer_end = answer_start + len(answer_text)
-    …
-    # Find token positions
-    start_token = end_token = cls_index
-    …
-        start_positions.append(cls_index)
-        end_positions.append(cls_index)
-    …
-    # …
-    …
-    # …
-    …
+def preprocess_single_example(example, tokenizer):
+    """Process a single example to avoid batch processing issues"""
+    # Tokenize
+    tokenized = tokenizer(
+        example["question"],
+        example["context"],
+        truncation="only_second",
+        max_length=MAX_LEN,
+        stride=DOC_STRIDE,
+        return_overflowing_tokens=True,
+        return_offsets_mapping=True,
+        padding="max_length",
+    )
+
+    results = {
+        "input_ids": [],
+        "attention_mask": [],
+        "start_positions": [],
+        "end_positions": [],
+        "example_id": [],
+        "offset_mapping": []
+    }
+
+    for i in range(len(tokenized["input_ids"])):
+        results["input_ids"].append(tokenized["input_ids"][i])
+        results["attention_mask"].append(tokenized["attention_mask"][i])
+        results["offset_mapping"].append(tokenized["offset_mapping"][i])
+        results["example_id"].append(example["id"])
+
+        # Handle answer positions
+        answers = example["answers"]
+        offsets = tokenized["offset_mapping"][i]
+        cls_index = 0
+
+        if not answers["text"] or not answers["text"][0]:
+            results["start_positions"].append(cls_index)
+            results["end_positions"].append(cls_index)
+            continue
+
+        answer_start = answers["answer_start"][0]
+        answer_text = answers["text"][0]
+        answer_end = answer_start + len(answer_text)
+
+        start_token = end_token = cls_index
+
+        for tok_idx, (start_char, end_char) in enumerate(offsets):
+            if start_char <= answer_start < end_char:
+                start_token = tok_idx
+            if start_char < answer_end <= end_char:
+                end_token = tok_idx
+                break
+
+        if start_token <= end_token and start_token > 0:
+            results["start_positions"].append(start_token)
+            results["end_positions"].append(end_token)
+        else:
+            results["start_positions"].append(cls_index)
+            results["end_positions"].append(cls_index)
+
+    return results
+
+def preprocess_with_chunking(dataset, dataset_name, tokenizer):
+    """Process dataset in chunks to avoid timeouts"""
+
+    # Check if final result already exists
+    final_checkpoint = f"{CHECKPOINT_DIR}/{dataset_name}_features.pt"
+    final_features = load_checkpoint(final_checkpoint)
+    if final_features is not None:
+        print(f"✅ Loaded {dataset_name} features from final checkpoint")
+        return Dataset.from_dict(final_features)
+
+    # Check if we can resume from chunks
+    combined_features = load_and_combine_chunks(dataset_name)
+    if combined_features is not None:
+        # Save as final checkpoint
+        save_checkpoint(combined_features, final_checkpoint)
+        return Dataset.from_dict(combined_features)
+
+    # Process in chunks
+    print(f"🔄 Processing {dataset_name} dataset in chunks of {CHUNK_SIZE}...")
+
+    total_samples = len(dataset)
+    num_chunks = (total_samples + CHUNK_SIZE - 1) // CHUNK_SIZE
+
+    for chunk_idx in range(num_chunks):
+        chunk_file = f"{CHECKPOINT_DIR}/{dataset_name}_chunk_{chunk_idx:04d}.pt"
+
+        # Skip if chunk already processed
+        if os.path.exists(chunk_file):
+            print(f"⏭️ Chunk {chunk_idx + 1}/{num_chunks} already exists, skipping...")
+            continue
+
+        start_idx = chunk_idx * CHUNK_SIZE
+        end_idx = min(start_idx + CHUNK_SIZE, total_samples)
+
+        print(f"🔄 Processing chunk {chunk_idx + 1}/{num_chunks} (samples {start_idx}-{end_idx-1})...")
+
+        chunk_results = {
+            "input_ids": [],
+            "attention_mask": [],
+            "start_positions": [],
+            "end_positions": [],
+            "example_id": [],
+            "offset_mapping": []
+        }
+
+        # Process each example in the chunk individually
+        for i in range(start_idx, end_idx):
+            if i % 10 == 0:  # Progress indicator
+                print(f"   Processing sample {i}/{total_samples}")

+            try:
+                example = dataset[i]
+                result = preprocess_single_example(example, tokenizer)
+
+                # Add to chunk results
+                for key in chunk_results:
+                    chunk_results[key].extend(result[key])
+
+            except Exception as e:
+                print(f"⚠️ Error processing sample {i}: {e}")
+                continue
+
+        # Save chunk
+        save_partial_features(chunk_results, chunk_idx, dataset_name)
+
+        # Clean up memory
+        del chunk_results
+        gc.collect()
+
+        print(f"✅ Chunk {chunk_idx + 1}/{num_chunks} completed and saved")
+
+    # Combine all chunks
+    print("🔄 Combining all chunks...")
+    combined_features = load_and_combine_chunks(dataset_name)
+
+    if combined_features is None:
+        raise RuntimeError("Failed to load and combine chunks!")
+
+    # Save final result
+    save_checkpoint(combined_features, final_checkpoint)
+
+    # Clean up chunk files
+    cleanup_chunk_files(dataset_name)
+
+    return Dataset.from_dict(combined_features)
+
+def cleanup_chunk_files(dataset_name):
+    """Remove chunk files after successful combination"""
+    if os.path.exists(CHECKPOINT_DIR):
+        for f in os.listdir(CHECKPOINT_DIR):
+            if f.startswith(f"{dataset_name}_chunk_") and f.endswith('.pt'):
+                try:
+                    os.remove(os.path.join(CHECKPOINT_DIR, f))
+                except:
+                    pass
+        print(f"🧹 Cleaned up chunk files for {dataset_name}")

 def main():
     global val_raw, val_feats, tok
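The char-to-token mapping in preprocess_single_example is the crux of the rewrite: offsets from return_offsets_mapping give each token a character span, special tokens get (0, 0), and answers that don't survive truncation fall back to the CLS index. A hand-worked sketch of those same comparisons (hand-written offsets stand in for a real tokenizer, so nothing downloads):

    # Token character spans for a 6-token window; (0, 0) marks special tokens.
    offsets = [(0, 0), (0, 4), (5, 9), (10, 18), (19, 23), (0, 0)]
    answer_start, answer_end = 10, 23   # answer covers characters 10-22

    start_token = end_token = 0         # CLS fallback, as in the script
    for tok_idx, (start_char, end_char) in enumerate(offsets):
        if start_char <= answer_start < end_char:
            start_token = tok_idx
        if start_char < answer_end <= end_char:
            end_token = tok_idx
            break

    print(start_token, end_token)       # 3 4 — the tokens covering the span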
@@ -229,8 +313,7 @@ def main():
         print("🔑 HuggingFace Hub login OK")
     except Exception as e:
         print(f"⚠️ Hub login failed: {e}")
-
-        tokn = None  # Disable pushing
+        tokn = None

     print("📚 Loading CUAD…")
     dataset_checkpoint = f"{CHECKPOINT_DIR}/cuad_dataset.pt"
@@ -241,17 +324,15 @@ def main():
         cuad = Dataset.from_dict(dataset_data)
         print(f"✅ Loaded dataset from checkpoint: {len(cuad)} examples")
     else:
-        # Load and process dataset
         try:
             cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
             print(f"✅ Loaded {len(cuad)} examples")
         except Exception as e:
             print(f"❌ Dataset loading failed: {e}")
-            print("🔄 Retrying with cache disabled...")
             cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True, download_mode="force_redownload")

     cuad = cuad.shuffle(seed=SEED)
-    cuad = balance_has_answer(cuad, ratio=2.0)
+    cuad = balance_has_answer(cuad, ratio=2.0)
     print(f"📊 Balanced dataset: {len(cuad)} examples")

     # Save dataset checkpoint
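balance_has_answer (defined earlier in the file) keeps every has-answer row and down-samples no-answer rows to at most ratio × the has-answer count. A minimal sketch of that invariant (illustrative only, not the function's exact body):

    import random

    def balance_sketch(rows, ratio=2.0, seed=42):
        has = [r for r in rows if r["answers"]["text"]]
        no  = [r for r in rows if not r["answers"]["text"]]
        random.Random(seed).shuffle(no)
        return has + no[: int(len(has) * ratio)]

    rows = [{"answers": {"text": ["x"]}}] * 2 + [{"answers": {"text": []}}] * 10
    print(len(balance_sketch(rows)))  # 6 = 2 has-answer + 4 kept no-answer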
@@ -263,7 +344,7 @@ def main():

     # ── tokeniser & model (SQuAD-2 tuned) ───────────────────────────────
     base_ckpt = "deepset/roberta-base-squad2"
-    tok = …
+    tok = AutoTokenizer.from_pretrained(base_ckpt, use_fast=True)
     model = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)

     # LoRA
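The LoraConfig itself sits outside the changed hunks, so it isn't shown here. For orientation only, a representative adapter config for a RoBERTa QA head (illustrative values, not the script's actual settings):

    from peft import LoraConfig, TaskType

    lora = LoraConfig(
        task_type=TaskType.QUESTION_ANS,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["query", "value"],  # RoBERTa attention projections
    )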
@@ -275,24 +356,30 @@ def main():
     model = get_peft_model(model, lora)
     model.print_trainable_parameters()

-    # ── preprocess with …
-    …
-    …
+    # ── preprocess with chunking ────────────────────────────────────────
+    print("🚀 Starting preprocessing...")
+
+    train_feats = preprocess_with_chunking(train_raw, "train", tok)
+    # Remove offset_mapping from training data
     if "offset_mapping" in train_feats.column_names:
         train_feats = train_feats.remove_columns(["offset_mapping"])

-    val_feats = …
-    # Keep offset_mapping for validation
+    val_feats = preprocess_with_chunking(val_raw, "val", tok)
+    # Keep offset_mapping for validation
+
+    print(f"✅ Preprocessing completed!")
+    print(f"   Training features: {len(train_feats)}")
+    print(f"   Validation features: {len(val_feats)}")

     # ── training args ──────────────────────────────────────────────────
     args = TrainingArguments(
         output_dir="./cuad_lora_out",
         learning_rate=3e-5,
         num_train_epochs=4,
-        per_device_train_batch_size=…
-        per_device_eval_batch_size=…
-        gradient_accumulation_steps=…
-        fp16=False, bf16=True,
+        per_device_train_batch_size=16,  # Increased for L40S
+        per_device_eval_batch_size=16,
+        gradient_accumulation_steps=2,   # Reduced since batch size increased
+        fp16=False, bf16=True,
         eval_strategy="steps",
         eval_steps=250,
         save_steps=500,
@@ -305,11 +392,9 @@ def main():
         greater_is_better=True,
         logging_steps=50,
         report_to="none",
-        # Add resume from checkpoint capability
         resume_from_checkpoint=True,
-        # …
-        …
-        dataloader_pin_memory=False,  # Reduce memory pressure
+        dataloader_num_workers=2,  # L40S can handle more workers
+        dataloader_pin_memory=True,
     )

     trainer = Trainer(
@@ -328,46 +413,41 @@ def main():
         print("✅ Training completed successfully!")
     except Exception as e:
         print(f"❌ Training failed: {e}")
-        print("💾 Attempting to save current state...")
         try:
             trainer.save_model("./cuad_lora_out_partial")
             tok.save_pretrained("./cuad_lora_out_partial")
-            print("💾 Partial model saved …")
+            print("💾 Partial model saved")
         except:
             print("❌ Could not save partial model")
         raise e

-    print("✅ Done. …")
+    print("✅ Done. Best F1:", trainer.state.best_metric)
     trainer.save_model("./cuad_lora_out")
     tok.save_pretrained("./cuad_lora_out")

-    # …
+    # Push to hub with retry logic
     if tokn:
-        …
-        for push_attempt in range(max_push_retries):
+        for attempt in range(3):
             try:
-                print(f"⬆️ Pushing to Hub (attempt {…")
+                print(f"⬆️ Pushing to Hub (attempt {attempt + 1}/3)...")
                 trainer.push_to_hub(model_repo, private=False)
                 tok.push_to_hub(model_repo, private=False)
                 print("🎉 Pushed to:", f"https://huggingface.co/{model_repo}")
                 break
             except Exception as e:
-                print(f"⚠️ Hub push failed …")
-                if …
-                    print("⏳ Waiting 30 seconds before retry...")
+                print(f"⚠️ Hub push failed: {e}")
+                if attempt < 2:
                     time.sleep(30)
                 else:
-                    print("💾 Model saved locally …")
-    else:
-        print("💾 Model saved locally in ./cuad_lora_out (no HF token for push)")
+                    print("💾 Model saved locally (push failed)")

-    # Clean up checkpoints
+    # Clean up checkpoints
     try:
         import shutil
         shutil.rmtree(CHECKPOINT_DIR)
         print("🧹 Cleaned up temporary checkpoints")
     except:
-        print("⚠️ Could not clean up …")
+        print("⚠️ Could not clean up checkpoints")

 if __name__ == "__main__":
     main()
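The push loop above hard-codes 3 attempts with flat 30-second waits. If the Space's network is flaky, the same pattern generalises to a small retry helper with exponential backoff — a sketch (the retry helper is hypothetical, not part of this script):

    import time

    def retry(fn, attempts=3, base_delay=30):
        """Call fn(), retrying with delays of 30s, 60s, 120s, ..."""
        for attempt in range(attempts):
            try:
                return fn()
            except Exception as e:
                print(f"⚠️ Attempt {attempt + 1}/{attempts} failed: {e}")
                if attempt < attempts - 1:
                    time.sleep(base_delay * 2 ** attempt)
        raise RuntimeError(f"all {attempts} attempts failed")

    # Usage, given trainer and model_repo from main():
    # retry(lambda: trainer.push_to_hub(model_repo, private=False))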