AvocadoMuffin committed · verified
Commit 53f26f3 · 1 Parent(s): 3368d9b

Update train.py

Files changed (1): train.py (+67 -10)
train.py CHANGED
@@ -12,9 +12,9 @@ import torch, numpy as np
 from datasets import load_dataset, Dataset, disable_caching
 from transformers import (
     AutoTokenizer, AutoModelForQuestionAnswering,
-    TrainingArguments, default_data_collator
+    TrainingArguments, default_data_collator, Trainer
 )
-from transformers import QuestionAnsweringTrainer, EvalPrediction
+# FIXED: Use regular Trainer instead of QuestionAnsweringTrainer
 from peft import LoraConfig, get_peft_model, TaskType
 import evaluate
 from huggingface_hub import login
@@ -80,7 +80,8 @@ def postprocess_qa(examples, features, raw_predictions, tokenizer):
     )
     return predictions
 
-def compute_metrics(eval_pred: EvalPrediction):
+def compute_metrics(eval_pred):
+    """FIXED: Use regular eval_pred structure instead of EvalPrediction"""
     predictions = postprocess_qa(raw_val, val_feats, eval_pred.predictions, tok)
     references = [
         {"id": ex["id"], "answers": ex["answers"]} for ex in raw_val
@@ -92,7 +93,7 @@ def compute_metrics(eval_pred: EvalPrediction):
 def main():
     set_seed(SEED)
 
-    # ο£Ώ model name to store on Hub
+    # model name to store on Hub
     model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-v2")
 
     if (tokn := os.getenv("roberta_token")):
@@ -124,7 +125,7 @@ def main():
 
     # ── preprocess ─────────────────────────────────────────────────────
     def preprocess(examples):
-        return tok(
+        tokenized = tok(
             examples["question"],
             examples["context"],
             truncation="only_second",
@@ -133,7 +134,62 @@ def main():
             return_overflowing_tokens=True,
             return_offsets_mapping=True,
             padding="max_length",
-        ) | { "example_id": examples["id"] }
+        )
+
+        # FIXED: Add proper answer position computation for QA
+        sample_mapping = tokenized.pop("overflow_to_sample_mapping")
+        offset_mapping = tokenized.pop("offset_mapping")
+
+        start_positions = []
+        end_positions = []
+
+        for i, offsets in enumerate(offset_mapping):
+            input_ids = tokenized["input_ids"][i]
+            cls_index = input_ids.index(tok.cls_token_id)
+
+            sequence_ids = tokenized.sequence_ids(i)
+            sample_index = sample_mapping[i]
+            answers = examples["answers"][sample_index]
+
+            # If no answers are given, set the cls_index as answer
+            if len(answers["answer_start"]) == 0:
+                start_positions.append(cls_index)
+                end_positions.append(cls_index)
+            else:
+                # Start/end character index of the answer in the text
+                start_char = answers["answer_start"][0]
+                end_char = start_char + len(answers["text"][0])
+
+                # Start token index of the current span in the text
+                token_start_index = 0
+                while sequence_ids[token_start_index] != 1:
+                    token_start_index += 1
+
+                # End token index of the current span in the text
+                token_end_index = len(input_ids) - 1
+                while sequence_ids[token_end_index] != 1:
+                    token_end_index -= 1
+
+                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index)
+                if not (offsets[token_start_index][0] <= start_char and
+                        offsets[token_end_index][1] >= end_char):
+                    start_positions.append(cls_index)
+                    end_positions.append(cls_index)
+                else:
+                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer
+                    # Note: we could go after the last offset if the answer is the last word (edge case)
+                    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
+                        token_start_index += 1
+                    start_positions.append(token_start_index - 1)
+
+                    while offsets[token_end_index][1] >= end_char:
+                        token_end_index -= 1
+                    end_positions.append(token_end_index + 1)
+
+        tokenized["start_positions"] = start_positions
+        tokenized["end_positions"] = end_positions
+        tokenized["example_id"] = [examples["id"][sample_mapping[i]] for i in range(len(tokenized["input_ids"]))]
+        return tokenized
 
     train_feats = train_raw.map(
         preprocess, batched=True, remove_columns=train_raw.column_names,
@@ -144,7 +200,7 @@ def main():
         num_proc=4, desc="tokenise-val"
     )
 
-    global raw_val  # for metric fn
+    global raw_val, val_feats  # for metric fn
    raw_val = val_raw
 
     # ── training args ──────────────────────────────────────────────────
@@ -156,7 +212,7 @@ def main():
         per_device_eval_batch_size=8,
         gradient_accumulation_steps=4,   # eff. BS 32
         fp16=False, bf16=True,           # L4 = bf16
-        evaluation_strategy="steps",
+        eval_strategy="steps",
         eval_steps=250,
         save_steps=500,
         save_total_limit=2,
@@ -170,7 +226,8 @@ def main():
         report_to="none",
     )
 
-    trainer = QuestionAnsweringTrainer(
+    # FIXED: Use regular Trainer instead of QuestionAnsweringTrainer
+    trainer = Trainer(
         model=model,
         args=args,
         train_dataset=train_feats,
@@ -194,4 +251,4 @@ def main():
     print("🚀 Pushed to:", f"https://huggingface.co/{model_repo}")
 
 if __name__ == "__main__":
-    main()
+    main()
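The core of the fix above is that `AutoModelForQuestionAnswering` computes the span-extraction loss on its own whenever `start_positions`/`end_positions` are present in the batch, so the generic `Trainer` is enough once `preprocess()` adds those columns. A minimal sketch of that behavior (the checkpoint and token positions here are illustrative assumptions, not taken from this repo):

```python
# Minimal sketch: a span-extraction head returns its own loss when gold
# start/end token positions are supplied, which is why the stock Trainer
# can drive training once the dataset carries those columns.
# NOTE: checkpoint and positions below are illustrative assumptions.
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

tok = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModelForQuestionAnswering.from_pretrained("roberta-base")

enc = tok("Who signed?", "The contract was signed by Acme.", return_tensors="pt")
out = model(**enc, start_positions=torch.tensor([7]), end_positions=torch.tensor([8]))
print(out.loss)  # cross-entropy over start/end logits
```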
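The answer-position loop added in `preprocess()` can be sanity-checked in isolation: tokenize one question/context pair with offsets, find the answer's tokens with the same character-offset comparisons the diff uses, and slice the context back out. A toy round trip (strings invented; assumes a fast tokenizer, which `roberta-base` provides):

```python
# Toy round trip for the offset-mapping logic the commit adds.
# Assumptions: fast tokenizer (offsets available); example strings invented.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("roberta-base")
question = "Who are the parties?"
context = "This Agreement is entered into by Acme Corp and Beta LLC."
answer = "Acme Corp"
start_char = context.index(answer)
end_char = start_char + len(answer)

enc = tok(question, context, return_offsets_mapping=True)
seq_ids = enc.sequence_ids()
offsets = enc["offset_mapping"]

# context tokens carry sequence id 1; special tokens map to (0, 0) and are skipped
start_tok = min(i for i, (s, e) in enumerate(offsets)
                if seq_ids[i] == 1 and s <= start_char < e)
end_tok = max(i for i, (s, e) in enumerate(offsets)
              if seq_ids[i] == 1 and s < end_char <= e)

print(context[offsets[start_tok][0]:offsets[end_tok][1]])  # -> Acme Corp
```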
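Finally, the `evaluation_strategy` → `eval_strategy` rename tracks `transformers` itself, which deprecated the old spelling (around v4.41) in favor of the new one. If the script must run across versions, a small shim can select whichever keyword the installed release defines; this is a convenience sketch, not part of the commit:

```python
# Version-agnostic shim (hedged): use whichever keyword the installed
# transformers TrainingArguments actually defines.
import inspect
from transformers import TrainingArguments

params = inspect.signature(TrainingArguments.__init__).parameters
strategy_kw = "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"

args = TrainingArguments(output_dir="out", **{strategy_kw: "steps"}, eval_steps=250)
print(strategy_kw)
```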