Update train.py

train.py
CHANGED
@@ -12,9 +12,9 @@ import torch, numpy as np
 from datasets import load_dataset, Dataset, disable_caching
 from transformers import (
     AutoTokenizer, AutoModelForQuestionAnswering,
-    TrainingArguments, default_data_collator
+    TrainingArguments, default_data_collator, Trainer
 )
-
+# FIXED: Use regular Trainer instead of QuestionAnsweringTrainer
 from peft import LoraConfig, get_peft_model, TaskType
 import evaluate
 from huggingface_hub import login
@@ -80,7 +80,8 @@ def postprocess_qa(examples, features, raw_predictions, tokenizer):
     )
     return predictions
 
-def compute_metrics(eval_pred: EvalPrediction):
+def compute_metrics(eval_pred):
+    """FIXED: Use regular eval_pred structure instead of EvalPrediction"""
     predictions = postprocess_qa(raw_val, val_feats, eval_pred.predictions, tok)
     references = [
         {"id": ex["id"], "answers": ex["answers"]} for ex in raw_val
@@ -92,7 +93,7 @@ def compute_metrics(eval_pred: EvalPrediction):
 def main():
     set_seed(SEED)
 
-    #
+    # model name to store on Hub
     model_repo = os.getenv("MODEL_NAME", "AvocadoMuffin/roberta-cuad-qa-v2")
 
     if (tokn := os.getenv("roberta_token")):
@@ -124,7 +125,7 @@ def main():
 
     # ── preprocess ──────────────────────────────────────────────────
     def preprocess(examples):
-        return tok(
+        tokenized = tok(
             examples["question"],
             examples["context"],
             truncation="only_second",
@@ -133,7 +134,62 @@ def main():
             return_overflowing_tokens=True,
             return_offsets_mapping=True,
             padding="max_length",
-        )
+        )
+
+        # FIXED: Add proper answer position computation for QA
+        sample_mapping = tokenized.pop("overflow_to_sample_mapping")
+        offset_mapping = tokenized.pop("offset_mapping")
+
+        start_positions = []
+        end_positions = []
+
+        for i, offsets in enumerate(offset_mapping):
+            input_ids = tokenized["input_ids"][i]
+            cls_index = input_ids.index(tok.cls_token_id)
+
+            sequence_ids = tokenized.sequence_ids(i)
+            sample_index = sample_mapping[i]
+            answers = examples["answers"][sample_index]
+
+            # If no answers are given, set the cls_index as answer
+            if len(answers["answer_start"]) == 0:
+                start_positions.append(cls_index)
+                end_positions.append(cls_index)
+            else:
+                # Start/end character index of the answer in the text
+                start_char = answers["answer_start"][0]
+                end_char = start_char + len(answers["text"][0])
+
+                # Start token index of the current span in the text
+                token_start_index = 0
+                while sequence_ids[token_start_index] != 1:
+                    token_start_index += 1
+
+                # End token index of the current span in the text
+                token_end_index = len(input_ids) - 1
+                while sequence_ids[token_end_index] != 1:
+                    token_end_index -= 1
+
+                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index)
+                if not (offsets[token_start_index][0] <= start_char and
+                        offsets[token_end_index][1] >= end_char):
+                    start_positions.append(cls_index)
+                    end_positions.append(cls_index)
+                else:
+                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer
+                    # Note: we could go after the last offset if the answer is the last word (edge case)
+                    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
+                        token_start_index += 1
+                    start_positions.append(token_start_index - 1)
+
+                    while offsets[token_end_index][1] >= end_char:
+                        token_end_index -= 1
+                    end_positions.append(token_end_index + 1)
+
+        tokenized["start_positions"] = start_positions
+        tokenized["end_positions"] = end_positions
+        tokenized["example_id"] = [examples["id"][sample_mapping[i]] for i in range(len(tokenized["input_ids"]))]
+        return tokenized
 
     train_feats = train_raw.map(
         preprocess, batched=True, remove_columns=train_raw.column_names,
@@ -144,7 +200,7 @@ def main():
         num_proc=4, desc="tokenise-val"
     )
 
-    global raw_val  # for metric fn
+    global raw_val, val_feats  # for metric fn
    raw_val = val_raw
 
     # ── training args ───────────────────────────────────────────────
@@ -156,7 +212,7 @@ def main():
         per_device_eval_batch_size=8,
         gradient_accumulation_steps=4,  # eff. BS 32
         fp16=False, bf16=True,          # L4 = bf16
-
+        eval_strategy="steps",
         eval_steps=250,
         save_steps=500,
         save_total_limit=2,
@@ -170,7 +226,8 @@ def main():
         report_to="none",
     )
 
-    trainer = QuestionAnsweringTrainer(
+    # FIXED: Use regular Trainer instead of QuestionAnsweringTrainer
+    trainer = Trainer(
         model=model,
         args=args,
         train_dataset=train_feats,
@@ -194,4 +251,4 @@ def main():
     print("🚀 Pushed to:", f"https://huggingface.co/{model_repo}")
 
 if __name__ == "__main__":
-    main()
+    main()
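
Note: the hunk at @@ -80,7 +80,8 @@ cuts compute_metrics off before its return. For readers following along, here is a minimal sketch of how the function plausibly finishes, assuming the script scores with the SQuAD metric from evaluate and that postprocess_qa yields an {example_id: answer_text} mapping; both are assumptions, and only the lines visible in the diff are confirmed by the commit.

    # Hypothetical continuation of compute_metrics -- not part of the commit.
    # Assumes postprocess_qa returns {example_id: answer_text} and that
    # raw_val, val_feats and tok are the module-level globals used above.
    import evaluate

    metric = evaluate.load("squad")  # assumption: SQuAD-style EM/F1 scoring

    def compute_metrics(eval_pred):
        predictions = postprocess_qa(raw_val, val_feats, eval_pred.predictions, tok)
        references = [
            {"id": ex["id"], "answers": ex["answers"]} for ex in raw_val
        ]
        formatted = [
            {"id": k, "prediction_text": v} for k, v in predictions.items()
        ]
        return metric.compute(predictions=formatted, references=references)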
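
The new answer-position loop is easy to spot-check without training: run preprocess on a toy example and decode the span that the labels select. Illustrative only; the toy data below is not from the commit, and it assumes tok is already loaded and preprocess is reachable at module level rather than nested inside main().

    # Sanity check for the span labelling (hypothetical example data).
    sample = {
        "id": ["ex-0"],
        "question": ["Who wrote the agreement?"],
        "context": ["This agreement was written by ACME Corp. in 2020."],
        "answers": [{"text": ["ACME Corp."], "answer_start": [30]}],
    }
    feats = preprocess(sample)
    s, e = feats["start_positions"][0], feats["end_positions"][0]
    span = tok.decode(feats["input_ids"][0][s : e + 1]).strip()
    print(span)  # should roughly recover "ACME Corp." modulo tokenisation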