Update train.py
train.py
CHANGED
@@ -1,8 +1,8 @@
 #!/usr/bin/env python
-#
+# train_cuad_lora_efficient.py
 """
-CUAD fine-tune with LoRA
-
+CUAD fine-tune with LoRA - Efficient batch processing version.
+Fixes bottlenecks and uses proper batching instead of chunking.
 """

 import os, json, random, gc, time
@@ -21,72 +21,42 @@ from huggingface_hub import login

 disable_caching()

-# ───────────────────────────────────────────────────────────────
+# ─────────────────────────────────────────────────────── config ──

-MAX_LEN
-DOC_STRIDE
-SEED
-
-
+MAX_LEN = 384
+DOC_STRIDE = 128
+SEED = 42
+BATCH_SIZE = 1000 # Process in larger, more efficient batches
+
+# Reduced dataset size option
+USE_SUBSET = True # Set to True to use only 10k examples
+SUBSET_SIZE = 10000

 def set_seed(seed):
     random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)

-def
-    """Save preprocessing checkpoint to disk"""
-    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
-    torch.save(data, checkpoint_path)
-    print(f"💾 Checkpoint saved: {checkpoint_path}")
-
-def load_checkpoint(checkpoint_path):
-    """Load preprocessing checkpoint from disk"""
-    if os.path.exists(checkpoint_path):
-        print(f"📂 Loading checkpoint: {checkpoint_path}")
-        return torch.load(checkpoint_path, map_location='cpu')
-    return None
-
-def save_partial_features(features_dict, chunk_idx, dataset_name):
-    """Save partial features for a chunk"""
-    partial_path = f"{CHECKPOINT_DIR}/{dataset_name}_chunk_{chunk_idx:04d}.pt"
-    save_checkpoint(features_dict, partial_path)
-    return partial_path
-
-def load_and_combine_chunks(dataset_name):
-    """Load all chunk files and combine them"""
-    chunk_files = []
-    if os.path.exists(CHECKPOINT_DIR):
-        for f in os.listdir(CHECKPOINT_DIR):
-            if f.startswith(f"{dataset_name}_chunk_") and f.endswith('.pt'):
-                chunk_files.append(os.path.join(CHECKPOINT_DIR, f))
-
-    if not chunk_files:
-        return None
-
-    chunk_files.sort()
-    print(f"📁 Found {len(chunk_files)} chunks for {dataset_name}")
-
-    # Combine all chunks
-    combined = None
-    for chunk_file in chunk_files:
-        chunk_data = torch.load(chunk_file, map_location='cpu')
-        if combined is None:
-            combined = chunk_data
-        else:
-            for key in chunk_data:
-                combined[key].extend(chunk_data[key])
-
-    print(f"✅ Combined {len(combined['input_ids'])} features from chunks")
-    return combined
-
-def balance_has_answer(dataset, ratio=2.0):
+def balance_has_answer(dataset, ratio=2.0, max_samples=None):
     """Keep all has-answer rows, down-sample no-answer rows to `ratio`."""
     has, no = [], []
     for ex in dataset:
         (has if ex["answers"]["text"] else no).append(ex)
+
+    print(f"📊 Original: {len(has)} has-answer, {len(no)} no-answer")
+
     k = int(len(has) * ratio)
     no = random.sample(no, min(k, len(no)))
-
+
+    balanced = has + no
+
+    # Apply subset limit if specified
+    if max_samples and len(balanced) > max_samples:
+        balanced = random.sample(balanced, max_samples)
+        print(f"📉 Reduced to {max_samples} samples for faster training")
+
+    print(f"📊 Balanced: {len([x for x in balanced if x['answers']['text']])} has-answer, {len([x for x in balanced if not x['answers']['text']])} no-answer")
+
+    return Dataset.from_list(balanced)

 # ────────────────────────────────────────────────────────────── postproc ──

@@ -104,20 +74,20 @@ def postprocess_qa(examples, features, raw_predictions, tokenizer):

     for example_idx, example in enumerate(examples):
         best_score = -1e9
-        best_span
-        context
+        best_span = ""
+        context = example["context"]

         for feat_idx in features_per_example[example_idx]:
             start_logit = all_start[feat_idx]
-            end_logit
-            offset
+            end_logit = all_end[feat_idx]
+            offset = features["offset_mapping"][feat_idx]

             start_idx = int(np.argmax(start_logit))
-            end_idx
+            end_idx = int(np.argmax(end_logit))

             if start_idx <= end_idx < len(offset):
                 start_char, _ = offset[start_idx]
-                _, end_char
+                _, end_char = offset[end_idx]
                 span = context[start_char:end_char].strip()
                 score = start_logit[start_idx] + end_logit[end_idx]
                 if score > best_score and span:
@@ -131,19 +101,25 @@ def postprocess_qa(examples, features, raw_predictions, tokenizer):
 def compute_metrics(eval_pred):
     """Use regular eval_pred structure and correct variable names"""
     predictions = postprocess_qa(val_raw, val_feats, eval_pred.predictions, tok)
-    references
+    references = [
         {"id": ex["id"], "answers": ex["answers"]} for ex in val_raw
     ]
     return metric.compute(predictions=predictions, references=references)

-#
+# ───────────────────────────────────────────────────────────── preprocessing ──

-def
-    """
-
-
-
-
+def preprocess_batch_efficient(examples, tokenizer):
+    """
+    Efficient batch preprocessing using HuggingFace's built-in batch processing.
+    This is much faster than processing examples individually.
+    """
+    questions = examples["question"]
+    contexts = examples["context"]
+
+    # Batch tokenization - this is the key efficiency gain
+    tokenized_examples = tokenizer(
+        questions,
+        contexts,
         truncation="only_second",
         max_length=MAX_LEN,
         stride=DOC_STRIDE,
@@ -152,160 +128,89 @@ def preprocess_single_example(example, tokenizer):
         padding="max_length",
     )

-
-
-
-
-
-
-        "offset_mapping": []
-    }
+    # Map back to original examples
+    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
+
+    # Initialize output
+    start_positions = []
+    end_positions = []

-    for i in
-
-
-        results["offset_mapping"].append(tokenized["offset_mapping"][i])
-        results["example_id"].append(example["id"])
+    for i, offsets in enumerate(tokenized_examples["offset_mapping"]):
+        input_ids = tokenized_examples["input_ids"][i]
+        cls_index = 0 # CLS token position

-        #
-
-
-        cls_index = 0
+        # Get the original example for this tokenized chunk
+        sample_index = sample_mapping[i]
+        answers = examples["answers"][sample_index]

+        # Handle cases with no answer
         if not answers["text"] or not answers["text"][0]:
-
-
+            start_positions.append(cls_index)
+            end_positions.append(cls_index)
             continue
-
-
+
+        # Find answer span in tokens
+        answer_start_char = answers["answer_start"][0]
         answer_text = answers["text"][0]
-
+        answer_end_char = answer_start_char + len(answer_text)

-
+        # Find token positions
+        token_start_index = cls_index
+        token_end_index = cls_index

-        for
-            if start_char <=
-
-            if start_char <
-
+        for token_index, (start_char, end_char) in enumerate(offsets):
+            if start_char <= answer_start_char < end_char:
+                token_start_index = token_index
+            if start_char < answer_end_char <= end_char:
+                token_end_index = token_index
                 break

-
-
-
+        # Validate positions
+        if token_start_index <= token_end_index and token_start_index > 0:
+            start_positions.append(token_start_index)
+            end_positions.append(token_end_index)
         else:
-
-
-
-    return results
-
-def preprocess_with_chunking(dataset, dataset_name, tokenizer):
-    """Process dataset in chunks to avoid timeouts"""
+            start_positions.append(cls_index)
+            end_positions.append(cls_index)

-
-
-    final_features = load_checkpoint(final_checkpoint)
-    if final_features is not None:
-        print(f"✅ Loaded {dataset_name} features from final checkpoint")
-        return Dataset.from_dict(final_features)
+    tokenized_examples["start_positions"] = start_positions
+    tokenized_examples["end_positions"] = end_positions

-    #
-
-
-
-        save_checkpoint(combined_features, final_checkpoint)
-        return Dataset.from_dict(combined_features)
-
-    # Process in chunks
-    print(f"🔄 Processing {dataset_name} dataset in chunks of {CHUNK_SIZE}...")
-
-    total_samples = len(dataset)
-    num_chunks = (total_samples + CHUNK_SIZE - 1) // CHUNK_SIZE
-
-    for chunk_idx in range(num_chunks):
-        chunk_file = f"{CHECKPOINT_DIR}/{dataset_name}_chunk_{chunk_idx:04d}.pt"
-
-        # Skip if chunk already processed
-        if os.path.exists(chunk_file):
-            print(f"⏭️ Chunk {chunk_idx + 1}/{num_chunks} already exists, skipping...")
-            continue
-
-        start_idx = chunk_idx * CHUNK_SIZE
-        end_idx = min(start_idx + CHUNK_SIZE, total_samples)
-
-        print(f"🔄 Processing chunk {chunk_idx + 1}/{num_chunks} (samples {start_idx}-{end_idx-1})...")
-
-        chunk_results = {
-            "input_ids": [],
-            "attention_mask": [],
-            "start_positions": [],
-            "end_positions": [],
-            "example_id": [],
-            "offset_mapping": []
-        }
-
-        # Process each example in the chunk individually
-        for i in range(start_idx, end_idx):
-            if i % 10 == 0: # Progress indicator
-                print(f" Processing sample {i}/{total_samples}")
-
-            try:
-                example = dataset[i]
-                result = preprocess_single_example(example, tokenizer)
-
-                # Add to chunk results
-                for key in chunk_results:
-                    chunk_results[key].extend(result[key])
-
-            except Exception as e:
-                print(f"⚠️ Error processing sample {i}: {e}")
-                continue
-
-        # Save chunk
-        save_partial_features(chunk_results, chunk_idx, dataset_name)
-
-        # Clean up memory
-        del chunk_results
-        gc.collect()
-
-        print(f"✅ Chunk {chunk_idx + 1}/{num_chunks} completed and saved")
-
-    # Combine all chunks
-    print("🔄 Combining all chunks...")
-    combined_features = load_and_combine_chunks(dataset_name)
-
-    if combined_features is None:
-        raise RuntimeError("Failed to load and combine chunks!")
+    # Add example IDs for evaluation
+    tokenized_examples["example_id"] = [
+        examples["id"][sample_mapping[i]] for i in range(len(tokenized_examples["input_ids"]))
+    ]

-
-
+    return tokenized_examples
+
+def preprocess_dataset_streaming(dataset, tokenizer, desc="Processing"):
+    """
+    Process dataset in batches using HuggingFace's map function with batching.
+    This is much more memory efficient and faster than manual chunking.
+    """
+    print(f"🔄 {desc} dataset with batch processing...")

-
-
+    processed = dataset.map(
+        lambda examples: preprocess_batch_efficient(examples, tokenizer),
+        batched=True,
+        batch_size=BATCH_SIZE,
+        remove_columns=dataset.column_names,
+        desc=desc,
+        num_proc=1, # Use 1 process to avoid memory issues in Spaces
+    )

-
-
-
-
-    if os.path.exists(CHECKPOINT_DIR):
-        for f in os.listdir(CHECKPOINT_DIR):
-            if f.startswith(f"{dataset_name}_chunk_") and f.endswith('.pt'):
-                try:
-                    os.remove(os.path.join(CHECKPOINT_DIR, f))
-                except:
-                    pass
-    print(f"🧹 Cleaned up chunk files for {dataset_name}")
+    print(f"✅ {desc} completed: {len(processed)} features")
+    return processed
+
+# ─────────────────────────────────────────────────────────────── main ──

 def main():
     global val_raw, val_feats, tok

     set_seed(SEED)
-
-    # Create checkpoint directory
-    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

     # Model name to store on Hub
-    model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-
+    model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-v3")

     if (tokn := os.getenv("roberta_token")):
         try:
@@ -316,27 +221,19 @@ def main():
         tokn = None

     print("📚 Loading CUAD…")
-
+    try:
+        cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
+        print(f"✅ Loaded {len(cuad)} examples")
+    except Exception as e:
+        print(f"❌ Dataset loading failed: {e}")
+        cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True, download_mode="force_redownload")

-
-
-
-
-
-
-    try:
-        cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True)
-        print(f"✅ Loaded {len(cuad)} examples")
-    except Exception as e:
-        print(f"❌ Dataset loading failed: {e}")
-        cuad = load_dataset("theatticusproject/cuad-qa", split="train", trust_remote_code=True, download_mode="force_redownload")
-
-    cuad = cuad.shuffle(seed=SEED)
-    cuad = balance_has_answer(cuad, ratio=2.0)
-    print(f"📊 Balanced dataset: {len(cuad)} examples")
-
-    # Save dataset checkpoint
-    save_checkpoint(cuad.to_dict(), dataset_checkpoint)
+    cuad = cuad.shuffle(seed=SEED)
+
+    # Apply subset reduction if enabled
+    subset_size = SUBSET_SIZE if USE_SUBSET else None
+    cuad = balance_has_answer(cuad, ratio=2.0, max_samples=subset_size)
+    print(f"📊 Final dataset size: {len(cuad)} examples")

     # train / val 90-10
     ds = cuad.train_test_split(test_size=0.1, seed=SEED)
@@ -347,42 +244,46 @@ def main():
     tok = AutoTokenizer.from_pretrained(base_ckpt, use_fast=True)
     model = AutoModelForQuestionAnswering.from_pretrained(base_ckpt)

-    # LoRA
+    # LoRA with slightly more aggressive settings for smaller dataset
     lora = LoraConfig(
         task_type=TaskType.QUESTION_ANS,
-        r=
-        target_modules=["query", "value"],
+        r=32, lora_alpha=64, lora_dropout=0.1, # Increased for better learning with less data
+        target_modules=["query", "value", "key", "dense"], # More modules for better coverage
     )
     model = get_peft_model(model, lora)
     model.print_trainable_parameters()

-    # ──
-    print("🚀 Starting preprocessing...")
+    # ── efficient preprocessing ─────────────────────────────────────────
+    print("🚀 Starting efficient preprocessing...")

-
-
+    # Process training data
+    train_feats = preprocess_dataset_streaming(train_raw, tok, "Training")
+    # Remove offset_mapping for training
     if "offset_mapping" in train_feats.column_names:
         train_feats = train_feats.remove_columns(["offset_mapping"])

-
-
+    # Process validation data (keep offset_mapping for evaluation)
+    val_feats = preprocess_dataset_streaming(val_raw, tok, "Validation")

     print(f"✅ Preprocessing completed!")
     print(f" Training features: {len(train_feats)}")
     print(f" Validation features: {len(val_feats)}")

     # ── training args ──────────────────────────────────────────────────
+    # Adjusted for smaller dataset
+    total_steps = (len(train_feats) // 16 // 2) * 6 # Rough estimate
+
     args = TrainingArguments(
         output_dir="./cuad_lora_out",
-        learning_rate=
-        num_train_epochs=4,
-        per_device_train_batch_size=16,
+        learning_rate=5e-5, # Slightly higher for smaller dataset
+        num_train_epochs=6 if USE_SUBSET else 4, # More epochs for smaller dataset
+        per_device_train_batch_size=16,
         per_device_eval_batch_size=16,
-        gradient_accumulation_steps=2,
+        gradient_accumulation_steps=2,
         fp16=False, bf16=True,
         eval_strategy="steps",
-        eval_steps=
-        save_steps=
+        eval_steps=max(100, total_steps // 20), # Adaptive eval steps
+        save_steps=max(200, total_steps // 10), # Adaptive save steps
         save_total_limit=2,
         weight_decay=0.01,
         lr_scheduler_type="cosine",
@@ -390,11 +291,11 @@ def main():
         load_best_model_at_end=True,
         metric_for_best_model="f1",
         greater_is_better=True,
-        logging_steps=
+        logging_steps=25,
         report_to="none",
-
-        dataloader_num_workers=2, # L40S can handle more workers
+        dataloader_num_workers=2,
         dataloader_pin_memory=True,
+        remove_unused_columns=False, # Keep example_id for evaluation
     )

     trainer = Trainer(
@@ -441,13 +342,5 @@ def main():
     else:
         print("💾 Model saved locally (push failed)")

-    # Clean up checkpoints
-    try:
-        import shutil
-        shutil.rmtree(CHECKPOINT_DIR)
-        print("🧹 Cleaned up temporary checkpoints")
-    except:
-        print("⚠️ Could not clean up checkpoints")
-
 if __name__ == "__main__":
     main()
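
The new preprocessing path relies on the fast tokenizer's overflow handling: a long contract is split into several overlapping MAX_LEN-token windows, and overflow_to_sample_mapping records which source example each window came from (this is what preprocess_batch_efficient pops and uses). A minimal sketch of that mechanism, using an assumed roberta-base checkpoint and toy data rather than anything from the commit:

    # Sketch of the batched, overflowing tokenization that preprocess_batch_efficient builds on.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("roberta-base", use_fast=True)  # assumed checkpoint, not taken from the script

    questions = ["Highlight the parts related to termination."]
    contexts = ["This Agreement may be terminated by either party upon written notice. " * 200]  # long enough to overflow

    enc = tok(
        questions,
        contexts,
        truncation="only_second",        # truncate only the context, never the question
        max_length=384,                  # MAX_LEN in the script
        stride=128,                      # DOC_STRIDE: overlap between consecutive windows
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    print(len(enc["input_ids"]))              # several features produced from the single example
    print(enc["overflow_to_sample_mapping"])  # e.g. [0, 0, 0, ...]: every feature maps back to example 0

The script itself is driven by the two environment variables read in main(); a hypothetical local run, with the file name and token value as placeholders, could look like:

    import os, subprocess

    os.environ["MODEL_NAME"] = "AvocadoMuffin/roberta-cuad-qa-v3"  # Hub repo, read via os.getenv("MODEL_NAME")
    os.environ["roberta_token"] = "<your-hf-token>"                # read via os.getenv("roberta_token")
    subprocess.run(["python", "train.py"], check=True)             # assumes the file is saved locally as train.py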
|