ilyada committed on
Commit
d6e2e1a
·
verified ·
1 Parent(s): ff94b50

Create train_web_accessibility.py

Browse files
Files changed (1) hide show
  1. train_web_accessibility.py +54 -0
train_web_accessibility.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Fine-tune ``bert-base-uncased`` as a two-label sequence classifier.

The script loads the ``ilyada/web_accessibility_dataset`` dataset, tokenizes
its ``"text"`` column, holds out 20% of the ``"train"`` split for evaluation,
trains with the Hugging Face ``Trainer``, prints the evaluation metrics and
publishes the model together with its tokenizer to the Hub.
"""

from datasets import load_dataset
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
)
import torch

# Hub repository that receives both the fine-tuned model and its tokenizer.
HUB_MODEL_ID = "ilyada/web_accessibility_model"

# Fixed seed so the train/eval split (and therefore the reported metrics)
# is the same on every run.
SPLIT_SEED = 42

# Load the dataset
dataset = load_dataset("ilyada/web_accessibility_dataset")

# Load pre-trained model and tokenizer
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)


def tokenize_function(examples):
    """Tokenize a batch of examples read from their ``"text"`` field.

    Args:
        examples: A batch (mapping of column name to list of values) as
            supplied by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output for the batch, truncated and padded to the
        tokenizer's maximum length so every example has the same size.
    """
    return tokenizer(examples["text"], padding="max_length", truncation=True)


# Tokenize the dataset
# NOTE(review): the Trainer needs a label column to compute a loss; presumably
# the dataset already provides one named "label"/"labels" -- confirm.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Split the dataset into train and test
train_test_split = tokenized_datasets["train"].train_test_split(
    test_size=0.2,
    seed=SPLIT_SEED,
)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    # NOTE(review): recent transformers releases rename this argument to
    # `eval_strategy`; confirm against the installed version.
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=True,  # This enables pushing the model to Hugging Face Hub
    hub_model_id=HUB_MODEL_ID,  # Replace with your Hugging Face model ID
    hub_strategy="end",
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

# Train the model
trainer.train()

# Evaluate the model
results = trainer.evaluate()
print(results)

# Push model to Hugging Face Hub
trainer.push_to_hub()

# The Trainer was not given the tokenizer, so it does not upload it; push it
# explicitly so the published repository can be loaded on its own.
tokenizer.push_to_hub(HUB_MODEL_ID)