DeFamy commited on
Commit
0af51d7
·
1 Parent(s): d69b533

Upload train_model.ipynb

Browse files
Files changed (1) hide show
  1. train_model.ipynb +909 -0
train_model.ipynb ADDED
@@ -0,0 +1,909 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "XLhB2j_Hemio"
7
+ },
8
+ "source": [
9
+ "## Read the dataset csv file"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "id": "hgYEtrYgemir",
17
+ "outputId": "d3ddedc7-8bd7-4ba9-c82e-68e4eb1309c3"
18
+ },
19
+ "outputs": [
20
+ {
21
+ "data": {
22
+ "text/html": [
23
+ "<div>\n",
24
+ "<style scoped>\n",
25
+ " .dataframe tbody tr th:only-of-type {\n",
26
+ " vertical-align: middle;\n",
27
+ " }\n",
28
+ "\n",
29
+ " .dataframe tbody tr th {\n",
30
+ " vertical-align: top;\n",
31
+ " }\n",
32
+ "\n",
33
+ " .dataframe thead th {\n",
34
+ " text-align: right;\n",
35
+ " }\n",
36
+ "</style>\n",
37
+ "<table border=\"1\" class=\"dataframe\">\n",
38
+ " <thead>\n",
39
+ " <tr style=\"text-align: right;\">\n",
40
+ " <th></th>\n",
41
+ " <th>Unnamed: 0</th>\n",
42
+ " <th>Text</th>\n",
43
+ " <th>target</th>\n",
44
+ " </tr>\n",
45
+ " </thead>\n",
46
+ " <tbody>\n",
47
+ " <tr>\n",
48
+ " <th>0</th>\n",
49
+ " <td>0.0</td>\n",
50
+ " <td>polis tangkap</td>\n",
51
+ " <td>NonCyberbully</td>\n",
52
+ " </tr>\n",
53
+ " <tr>\n",
54
+ " <th>1</th>\n",
55
+ " <td>1.0</td>\n",
56
+ " <td>kenapa lokasi kebakaran terlalu spesifik</td>\n",
57
+ " <td>NonCyberbully</td>\n",
58
+ " </tr>\n",
59
+ " <tr>\n",
60
+ " <th>2</th>\n",
61
+ " <td>2.0</td>\n",
62
+ " <td>menyesal tanya nak for birthday</td>\n",
63
+ " <td>NonCyberbully</td>\n",
64
+ " </tr>\n",
65
+ " <tr>\n",
66
+ " <th>3</th>\n",
67
+ " <td>3.0</td>\n",
68
+ " <td>meriah tah</td>\n",
69
+ " <td>NonCyberbully</td>\n",
70
+ " </tr>\n",
71
+ " <tr>\n",
72
+ " <th>4</th>\n",
73
+ " <td>4.0</td>\n",
74
+ " <td>asal bs kelar kerja jam sik kl baru diajak mee...</td>\n",
75
+ " <td>NonCyberbully</td>\n",
76
+ " </tr>\n",
77
+ " </tbody>\n",
78
+ "</table>\n",
79
+ "</div>"
80
+ ],
81
+ "text/plain": [
82
+ " Unnamed: 0 Text \\\n",
83
+ "0 0.0 polis tangkap \n",
84
+ "1 1.0 kenapa lokasi kebakaran terlalu spesifik \n",
85
+ "2 2.0 menyesal tanya nak for birthday \n",
86
+ "3 3.0 meriah tah \n",
87
+ "4 4.0 asal bs kelar kerja jam sik kl baru diajak mee... \n",
88
+ "\n",
89
+ " target \n",
90
+ "0 NonCyberbully \n",
91
+ "1 NonCyberbully \n",
92
+ "2 NonCyberbully \n",
93
+ "3 NonCyberbully \n",
94
+ "4 NonCyberbully "
95
+ ]
96
+ },
97
+ "execution_count": 3,
98
+ "metadata": {},
99
+ "output_type": "execute_result"
100
+ }
101
+ ],
102
+ "source": [
103
+ "import pandas as pd\n",
104
+ "df = pd.read_csv('C:/Users/user/Documents/PSM/BERT_Ver2/Transformers-Text-Classification-BERT-Blog-main/input/Tagged_MixedNew.csv')\n",
105
+ "df.head()"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "metadata": {
111
+ "id": "fGUtFkVfemit"
112
+ },
113
+ "source": [
114
+ "## Process the data"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "metadata": {
121
+ "id": "7C3uRWECemiu",
122
+ "outputId": "8e764d84-010d-4e42-987a-af7162627f6e",
123
+ "colab": {
124
+ "referenced_widgets": [
125
+ "042c8b0b8dcf42eb84660c93778d8ea7",
126
+ "4ab6074437a849f79be038b043025283",
127
+ "9aed4d88c18e4e28a1efbbed94331228"
128
+ ]
129
+ }
130
+ },
131
+ "outputs": [
132
+ {
133
+ "data": {
134
+ "application/vnd.jupyter.widget-view+json": {
135
+ "model_id": "042c8b0b8dcf42eb84660c93778d8ea7",
136
+ "version_major": 2,
137
+ "version_minor": 0
138
+ },
139
+ "text/plain": [
140
+ "Downloading (…)okenizer_config.json: 0%| | 0.00/380 [00:00<?, ?B/s]"
141
+ ]
142
+ },
143
+ "metadata": {},
144
+ "output_type": "display_data"
145
+ },
146
+ {
147
+ "name": "stderr",
148
+ "output_type": "stream",
149
+ "text": [
150
+ "C:\\Users\\user\\anaconda3\\lib\\site-packages\\huggingface_hub\\file_download.py:133: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\user\\.cache\\huggingface\\hub. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
151
+ "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
152
+ " warnings.warn(message)\n"
153
+ ]
154
+ },
155
+ {
156
+ "data": {
157
+ "application/vnd.jupyter.widget-view+json": {
158
+ "model_id": "4ab6074437a849f79be038b043025283",
159
+ "version_major": 2,
160
+ "version_minor": 0
161
+ },
162
+ "text/plain": [
163
+ "Downloading (…)solve/main/vocab.txt: 0%| | 0.00/233k [00:00<?, ?B/s]"
164
+ ]
165
+ },
166
+ "metadata": {},
167
+ "output_type": "display_data"
168
+ },
169
+ {
170
+ "data": {
171
+ "application/vnd.jupyter.widget-view+json": {
172
+ "model_id": "9aed4d88c18e4e28a1efbbed94331228",
173
+ "version_major": 2,
174
+ "version_minor": 0
175
+ },
176
+ "text/plain": [
177
+ "Downloading (…)cial_tokens_map.json: 0%| | 0.00/125 [00:00<?, ?B/s]"
178
+ ]
179
+ },
180
+ "metadata": {},
181
+ "output_type": "display_data"
182
+ }
183
+ ],
184
+ "source": [
185
+ "#from transformers import BertTokenizer\n",
186
+ "#tokenizer = BertTokenizer.from_pretrained('malay-huggingface/bert-tiny-bahasa-cased')\n",
187
+ "\n",
188
+ "from transformers import AutoTokenizer\n",
189
+ "tokenizer = AutoTokenizer.from_pretrained('mesolitica/bert-base-standard-bahasa-cased')"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "metadata": {
196
+ "id": "Ks3XobW0emiu"
197
+ },
198
+ "outputs": [],
199
+ "source": [
200
+ "import numpy as np\n",
201
+ "from sklearn.model_selection import train_test_split\n",
202
+ "from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score\n",
203
+ "import torch\n",
204
+ "from transformers import TrainingArguments, Trainer\n",
205
+ "from transformers import BertTokenizer, BertForSequenceClassification"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "execution_count": null,
211
+ "metadata": {
212
+ "id": "0ZZx6mUdemiv"
213
+ },
214
+ "outputs": [],
215
+ "source": [
216
+ "def process_data(row):\n",
217
+ "\n",
218
+ " text = row['Text']\n",
219
+ " text = str(text)\n",
220
+ " text = ' '.join(text.split())\n",
221
+ "\n",
222
+ " encodings = tokenizer(text, padding=\"max_length\", truncation=True, max_length=128)\n",
223
+ "\n",
224
+ " label = 0\n",
225
+ " if row['target'] == 'Cyberbully':\n",
226
+ " label += 1\n",
227
+ "\n",
228
+ " encodings['label'] = label\n",
229
+ " encodings['Text'] = text\n",
230
+ "\n",
231
+ " return encodings"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": null,
237
+ "metadata": {
238
+ "id": "MaFmqSc-emiv",
239
+ "outputId": "03eb6491-b646-45dd-ef3d-318c81313430"
240
+ },
241
+ "outputs": [
242
+ {
243
+ "name": "stdout",
244
+ "output_type": "stream",
245
+ "text": [
246
+ "{'input_ids': [2, 2039, 3058, 9857, 1606, 1164, 2161, 8062, 1219, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'label': 0, 'Text': 'Saya suka masakan beliau dan cara penyampaiannya'}\n"
247
+ ]
248
+ }
249
+ ],
250
+ "source": [
251
+ "print(process_data({\n",
252
+ " 'Text': 'Saya suka masakan beliau dan cara penyampaiannya',\n",
253
+ " 'target': 'NonCyberbully'\n",
254
+ "}))"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "metadata": {
261
+ "id": "Lel-2lqKemiw"
262
+ },
263
+ "outputs": [],
264
+ "source": [
265
+ "processed_data = []\n",
266
+ "\n",
267
+ "for i in range(len(df[:1383])):\n",
268
+ " processed_data.append(process_data(df.iloc[i]))"
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "markdown",
273
+ "metadata": {
274
+ "id": "x_DGsKzHemiw"
275
+ },
276
+ "source": [
277
+ "## Generate the dataset"
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "code",
282
+ "execution_count": null,
283
+ "metadata": {
284
+ "id": "oc_NsbnXemiw"
285
+ },
286
+ "outputs": [],
287
+ "source": [
288
+ "from sklearn.model_selection import train_test_split\n",
289
+ "\n",
290
+ "new_df = pd.DataFrame(processed_data)\n",
291
+ "\n",
292
+ "train_df, valid_df = train_test_split(\n",
293
+ " new_df,\n",
294
+ " test_size=0.2,\n",
295
+ " random_state=2022\n",
296
+ ")"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "code",
301
+ "execution_count": null,
302
+ "metadata": {
303
+ "id": "4qSci5CRemix"
304
+ },
305
+ "outputs": [],
306
+ "source": [
307
+ "import pyarrow as pa\n",
308
+ "from datasets import Dataset\n",
309
+ "\n",
310
+ "train_hg = Dataset(pa.Table.from_pandas(train_df))\n",
311
+ "valid_hg = Dataset(pa.Table.from_pandas(valid_df))"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "execution_count": null,
317
+ "metadata": {
318
+ "id": "xDgnim7iemix",
319
+ "outputId": "59858161-59a4-4731-fbfc-7e30a1246eed"
320
+ },
321
+ "outputs": [
322
+ {
323
+ "data": {
324
+ "text/plain": [
325
+ "Dataset({\n",
326
+ " features: ['Text', 'attention_mask', 'input_ids', 'label', 'token_type_ids', '__index_level_0__'],\n",
327
+ " num_rows: 277\n",
328
+ "})"
329
+ ]
330
+ },
331
+ "execution_count": 12,
332
+ "metadata": {},
333
+ "output_type": "execute_result"
334
+ }
335
+ ],
336
+ "source": [
337
+ "valid_hg"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "markdown",
342
+ "metadata": {
343
+ "id": "8Uqq0cKKemiy"
344
+ },
345
+ "source": [
346
+ "## Create a model"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "code",
351
+ "execution_count": null,
352
+ "metadata": {
353
+ "id": "QQkDAXmRemiz",
354
+ "outputId": "e00faff0-c7d7-456d-dab2-73d9839c0274",
355
+ "colab": {
356
+ "referenced_widgets": [
357
+ "b9faad28a43547029c8b13ab639f8d05",
358
+ "6175ea4206304020823d86e0bbc23298"
359
+ ]
360
+ }
361
+ },
362
+ "outputs": [
363
+ {
364
+ "data": {
365
+ "application/vnd.jupyter.widget-view+json": {
366
+ "model_id": "b9faad28a43547029c8b13ab639f8d05",
367
+ "version_major": 2,
368
+ "version_minor": 0
369
+ },
370
+ "text/plain": [
371
+ "Downloading (…)lve/main/config.json: 0%| | 0.00/697 [00:00<?, ?B/s]"
372
+ ]
373
+ },
374
+ "metadata": {},
375
+ "output_type": "display_data"
376
+ },
377
+ {
378
+ "data": {
379
+ "application/vnd.jupyter.widget-view+json": {
380
+ "model_id": "6175ea4206304020823d86e0bbc23298",
381
+ "version_major": 2,
382
+ "version_minor": 0
383
+ },
384
+ "text/plain": [
385
+ "Downloading pytorch_model.bin: 0%| | 0.00/443M [00:00<?, ?B/s]"
386
+ ]
387
+ },
388
+ "metadata": {},
389
+ "output_type": "display_data"
390
+ },
391
+ {
392
+ "name": "stderr",
393
+ "output_type": "stream",
394
+ "text": [
395
+ "Some weights of the model checkpoint at mesolitica/bert-base-standard-bahasa-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']\n",
396
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
397
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
398
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at mesolitica/bert-base-standard-bahasa-cased and are newly initialized: ['classifier.bias', 'bert.pooler.dense.bias', 'classifier.weight', 'bert.pooler.dense.weight']\n",
399
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
400
+ ]
401
+ }
402
+ ],
403
+ "source": [
404
+ "#from transformers import BertForSequenceClassification\n",
405
+ "\n",
406
+ "#model = BertForSequenceClassification.from_pretrained(\n",
407
+ "# 'malay-huggingface/bert-tiny-bahasa-cased',\n",
408
+ "# num_labels=2\n",
409
+ "#)\n",
410
+ "\n",
411
+ "\n",
412
+ "from transformers import AutoModelForSequenceClassification\n",
413
+ "\n",
414
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
415
+ " 'mesolitica/bert-base-standard-bahasa-cased',\n",
416
+ " num_labels=2\n",
417
+ ")"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "code",
422
+ "execution_count": null,
423
+ "metadata": {
424
+ "id": "ifvtnwBMemi1"
425
+ },
426
+ "outputs": [],
427
+ "source": [
428
+ "def compute_metrics(p):\n",
429
+ " print(type(p))\n",
430
+ " pred, labels = p\n",
431
+ " pred = np.argmax(pred, axis=1)\n",
432
+ "\n",
433
+ " accuracy = accuracy_score(y_true=labels, y_pred=pred)\n",
434
+ " recall = recall_score(y_true=labels, y_pred=pred)\n",
435
+ " precision = precision_score(y_true=labels, y_pred=pred)\n",
436
+ " f1 = f1_score(y_true=labels, y_pred=pred)\n",
437
+ "\n",
438
+ " return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1}\n",
439
+ ""
440
+ ]
441
+ },
442
+ {
443
+ "cell_type": "code",
444
+ "execution_count": null,
445
+ "metadata": {
446
+ "id": "50Xy9P7Remi2"
447
+ },
448
+ "outputs": [],
449
+ "source": [
450
+ "from transformers import TrainingArguments, Trainer\n",
451
+ "\n",
452
+ "training_args = TrainingArguments(output_dir=\"./result\", evaluation_strategy=\"epoch\")\n",
453
+ "\n",
454
+ "trainer = Trainer(\n",
455
+ " model=model,\n",
456
+ " args=training_args,\n",
457
+ " train_dataset=train_hg,\n",
458
+ " eval_dataset=valid_hg,\n",
459
+ " tokenizer=tokenizer,\n",
460
+ " compute_metrics=compute_metrics\n",
461
+ ")"
462
+ ]
463
+ },
464
+ {
465
+ "cell_type": "markdown",
466
+ "metadata": {
467
+ "id": "myIstfgJemi3"
468
+ },
469
+ "source": [
470
+ "## Train and Evaluate the model"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "metadata": {
477
+ "id": "-UtAkNHUemi4",
478
+ "outputId": "5af038f3-a77c-41eb-e48d-747a8e776e38"
479
+ },
480
+ "outputs": [
481
+ {
482
+ "name": "stderr",
483
+ "output_type": "stream",
484
+ "text": [
485
+ "C:\\Users\\user\\anaconda3\\lib\\site-packages\\transformers\\optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
486
+ " warnings.warn(\n",
487
+ "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
488
+ ]
489
+ },
490
+ {
491
+ "data": {
492
+ "text/html": [
493
+ "\n",
494
+ " <div>\n",
495
+ " \n",
496
+ " <progress value='417' max='417' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
497
+ " [417/417 56:36, Epoch 3/3]\n",
498
+ " </div>\n",
499
+ " <table border=\"1\" class=\"dataframe\">\n",
500
+ " <thead>\n",
501
+ " <tr style=\"text-align: left;\">\n",
502
+ " <th>Epoch</th>\n",
503
+ " <th>Training Loss</th>\n",
504
+ " <th>Validation Loss</th>\n",
505
+ " <th>Accuracy</th>\n",
506
+ " <th>Precision</th>\n",
507
+ " <th>Recall</th>\n",
508
+ " <th>F1</th>\n",
509
+ " </tr>\n",
510
+ " </thead>\n",
511
+ " <tbody>\n",
512
+ " <tr>\n",
513
+ " <td>1</td>\n",
514
+ " <td>No log</td>\n",
515
+ " <td>0.493876</td>\n",
516
+ " <td>0.779783</td>\n",
517
+ " <td>0.657343</td>\n",
518
+ " <td>0.886792</td>\n",
519
+ " <td>0.755020</td>\n",
520
+ " </tr>\n",
521
+ " <tr>\n",
522
+ " <td>2</td>\n",
523
+ " <td>No log</td>\n",
524
+ " <td>0.542367</td>\n",
525
+ " <td>0.870036</td>\n",
526
+ " <td>0.850000</td>\n",
527
+ " <td>0.801887</td>\n",
528
+ " <td>0.825243</td>\n",
529
+ " </tr>\n",
530
+ " <tr>\n",
531
+ " <td>3</td>\n",
532
+ " <td>No log</td>\n",
533
+ " <td>0.725669</td>\n",
534
+ " <td>0.848375</td>\n",
535
+ " <td>0.820000</td>\n",
536
+ " <td>0.773585</td>\n",
537
+ " <td>0.796117</td>\n",
538
+ " </tr>\n",
539
+ " </tbody>\n",
540
+ "</table><p>"
541
+ ],
542
+ "text/plain": [
543
+ "<IPython.core.display.HTML object>"
544
+ ]
545
+ },
546
+ "metadata": {},
547
+ "output_type": "display_data"
548
+ },
549
+ {
550
+ "name": "stdout",
551
+ "output_type": "stream",
552
+ "text": [
553
+ "<class 'transformers.trainer_utils.EvalPrediction'>\n",
554
+ "<class 'transformers.trainer_utils.EvalPrediction'>\n",
555
+ "<class 'transformers.trainer_utils.EvalPrediction'>\n"
556
+ ]
557
+ },
558
+ {
559
+ "data": {
560
+ "text/plain": [
561
+ "TrainOutput(global_step=417, training_loss=0.2771467213436282, metrics={'train_runtime': 3405.0836, 'train_samples_per_second': 0.974, 'train_steps_per_second': 0.122, 'total_flos': 218053287129600.0, 'train_loss': 0.2771467213436282, 'epoch': 3.0})"
562
+ ]
563
+ },
564
+ "execution_count": 16,
565
+ "metadata": {},
566
+ "output_type": "execute_result"
567
+ }
568
+ ],
569
+ "source": [
570
+ "trainer.train()"
571
+ ]
572
+ },
573
+ {
574
+ "cell_type": "code",
575
+ "execution_count": null,
576
+ "metadata": {
577
+ "id": "fZYGhNyremi4",
578
+ "outputId": "5119c379-d7e9-48f7-9137-d788f99a3731"
579
+ },
580
+ "outputs": [
581
+ {
582
+ "data": {
583
+ "text/html": [
584
+ "\n",
585
+ " <div>\n",
586
+ " \n",
587
+ " <progress value='35' max='35' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
588
+ " [35/35 00:43]\n",
589
+ " </div>\n",
590
+ " "
591
+ ],
592
+ "text/plain": [
593
+ "<IPython.core.display.HTML object>"
594
+ ]
595
+ },
596
+ "metadata": {},
597
+ "output_type": "display_data"
598
+ },
599
+ {
600
+ "name": "stdout",
601
+ "output_type": "stream",
602
+ "text": [
603
+ "<class 'transformers.trainer_utils.EvalPrediction'>\n"
604
+ ]
605
+ },
606
+ {
607
+ "data": {
608
+ "text/plain": [
609
+ "{'eval_loss': 0.7256694436073303,\n",
610
+ " 'eval_accuracy': 0.8483754512635379,\n",
611
+ " 'eval_precision': 0.82,\n",
612
+ " 'eval_recall': 0.7735849056603774,\n",
613
+ " 'eval_f1': 0.796116504854369,\n",
614
+ " 'eval_runtime': 44.9419,\n",
615
+ " 'eval_samples_per_second': 6.164,\n",
616
+ " 'eval_steps_per_second': 0.779,\n",
617
+ " 'epoch': 3.0}"
618
+ ]
619
+ },
620
+ "execution_count": 17,
621
+ "metadata": {},
622
+ "output_type": "execute_result"
623
+ }
624
+ ],
625
+ "source": [
626
+ "trainer.evaluate()"
627
+ ]
628
+ },
629
+ {
630
+ "cell_type": "markdown",
631
+ "metadata": {
632
+ "id": "tlw24Ccdemi5"
633
+ },
634
+ "source": [
635
+ "## Save the model"
636
+ ]
637
+ },
638
+ {
639
+ "cell_type": "code",
640
+ "execution_count": null,
641
+ "metadata": {
642
+ "id": "69n4eVBHemi6"
643
+ },
644
+ "outputs": [],
645
+ "source": [
646
+ "model.save_pretrained('./model/')"
647
+ ]
648
+ },
649
+ {
650
+ "cell_type": "code",
651
+ "execution_count": null,
652
+ "metadata": {
653
+ "id": "gC9qDoERemi6",
654
+ "outputId": "a5514df7-d322-48b9-df27-c799dca6d884"
655
+ },
656
+ "outputs": [
657
+ {
658
+ "name": "stdout",
659
+ "output_type": "stream",
660
+ "text": [
661
+ "Looking in indexes: https://download.pytorch.org/whl/cu117\n",
662
+ "Requirement already satisfied: torch in c:\\users\\user\\anaconda3\\lib\\site-packages (2.0.1+cu118)\n",
663
+ "Requirement already satisfied: torchvision in c:\\users\\user\\anaconda3\\lib\\site-packages (0.15.2+cu117)\n",
664
+ "Requirement already satisfied: torchaudio in c:\\users\\user\\anaconda3\\lib\\site-packages (2.0.2+cu117)\n",
665
+ "Requirement already satisfied: sympy in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (1.11.1)\n",
666
+ "Requirement already satisfied: jinja2 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (3.1.2)\n",
667
+ "Requirement already satisfied: filelock in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (3.9.0)\n",
668
+ "Requirement already satisfied: networkx in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (2.5.1)\n",
669
+ "Requirement already satisfied: typing-extensions in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (4.4.0)\n",
670
+ "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (9.4.0)\n",
671
+ "Requirement already satisfied: numpy in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (1.23.5)\n",
672
+ "Requirement already satisfied: requests in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (2.28.1)\n",
673
+ "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from jinja2->torch) (2.1.1)\n",
674
+ "Requirement already satisfied: decorator<5,>=4.3 in c:\\users\\user\\anaconda3\\lib\\site-packages (from networkx->torch) (4.4.2)\n",
675
+ "Requirement already satisfied: charset-normalizer<3,>=2 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.0.4)\n",
676
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (1.26.14)\n",
677
+ "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.10)\n",
678
+ "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2022.12.7)\n",
679
+ "Requirement already satisfied: mpmath>=0.19 in c:\\users\\user\\anaconda3\\lib\\site-packages (from sympy->torch) (1.2.1)\n"
680
+ ]
681
+ }
682
+ ],
683
+ "source": [
684
+ "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117"
685
+ ]
686
+ },
687
+ {
688
+ "cell_type": "code",
689
+ "execution_count": null,
690
+ "metadata": {
691
+ "id": "3NBugUKAemi7"
692
+ },
693
+ "outputs": [],
694
+ "source": []
695
+ },
696
+ {
697
+ "cell_type": "code",
698
+ "execution_count": null,
699
+ "metadata": {
700
+ "id": "-W3_K_Kjemi7"
701
+ },
702
+ "outputs": [],
703
+ "source": []
704
+ },
705
+ {
706
+ "cell_type": "markdown",
707
+ "metadata": {
708
+ "id": "yMiT54Ddemi7"
709
+ },
710
+ "source": [
711
+ "## Load the model"
712
+ ]
713
+ },
714
+ {
715
+ "cell_type": "code",
716
+ "execution_count": null,
717
+ "metadata": {
718
+ "id": "mEFnUaM3emi7"
719
+ },
720
+ "outputs": [],
721
+ "source": [
722
+ "import torch\n",
723
+ "from transformers import AutoModelForSequenceClassification\n",
724
+ "\n",
725
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
726
+ "\n",
727
+ "new_model = AutoModelForSequenceClassification.from_pretrained('./model/').to(device)"
728
+ ]
729
+ },
730
+ {
731
+ "cell_type": "code",
732
+ "execution_count": null,
733
+ "metadata": {
734
+ "id": "zkDeulcTemi8",
735
+ "outputId": "2500b324-398b-471b-9c08-48fa79ea9de3"
736
+ },
737
+ "outputs": [
738
+ {
739
+ "name": "stderr",
740
+ "output_type": "stream",
741
+ "text": [
742
+ "ERROR: torch-1.0.1-cp36-cp36m-win_amd64.whl is not a supported wheel on this platform.\n",
743
+ "\n",
744
+ "[notice] A new release of pip is available: 23.0.1 -> 23.1.2\n",
745
+ "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
746
+ ]
747
+ },
748
+ {
749
+ "name": "stdout",
750
+ "output_type": "stream",
751
+ "text": [
752
+ "Requirement already satisfied: torchvision in c:\\users\\user\\anaconda3\\lib\\site-packages (0.14.0)\n",
753
+ "Requirement already satisfied: typing-extensions in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (4.1.1)\n",
754
+ "Requirement already satisfied: requests in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (2.27.1)\n",
755
+ "Requirement already satisfied: torch==1.13.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (1.13.0)\n",
756
+ "Requirement already satisfied: numpy in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (1.24.2)\n",
757
+ "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (9.0.1)\n",
758
+ "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (3.3)\n",
759
+ "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.0.4)\n",
760
+ "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2022.9.24)\n",
761
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (1.26.9)\n"
762
+ ]
763
+ },
764
+ {
765
+ "name": "stderr",
766
+ "output_type": "stream",
767
+ "text": [
768
+ "\n",
769
+ "[notice] A new release of pip is available: 23.0.1 -> 23.1.2\n",
770
+ "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
771
+ ]
772
+ }
773
+ ],
774
+ "source": [
775
+ "!pip install https://download.pytorch.org/whl/cpu/torch-1.0.1-cp36-cp36m-win_amd64.whl\n",
776
+ "!pip install torchvision"
777
+ ]
778
+ },
779
+ {
780
+ "cell_type": "code",
781
+ "execution_count": null,
782
+ "metadata": {
783
+ "id": "WtI-WDBhemi8"
784
+ },
785
+ "outputs": [],
786
+ "source": [
787
+ "from transformers import AutoTokenizer\n",
788
+ "\n",
789
+ "new_tokenizer = AutoTokenizer.from_pretrained('mesolitica/bert-base-standard-bahasa-cased')"
790
+ ]
791
+ },
792
+ {
793
+ "cell_type": "markdown",
794
+ "metadata": {
795
+ "id": "S2X_uPYJemi9"
796
+ },
797
+ "source": [
798
+ "## Get predictions"
799
+ ]
800
+ },
801
+ {
802
+ "cell_type": "code",
803
+ "execution_count": null,
804
+ "metadata": {
805
+ "id": "qXKQEiWxemi9"
806
+ },
807
+ "outputs": [],
808
+ "source": [
809
+ "import torch\n",
810
+ "import numpy as np\n",
811
+ "\n",
812
+ "def get_prediction(text):\n",
813
+ " encoding = new_tokenizer(text, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=128)\n",
814
+ " encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}\n",
815
+ "\n",
816
+ " outputs = new_model(**encoding)\n",
817
+ "\n",
818
+ " logits = outputs.logits\n",
819
+ " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
820
+ " sigmoid = torch.nn.Sigmoid()\n",
821
+ " print(sigmoid)\n",
822
+ " probs = sigmoid(logits.squeeze().cpu())\n",
823
+ " probs = probs.detach().numpy()\n",
824
+ " label = np.argmax(probs, axis=-1)\n",
825
+ "\n",
826
+ " if label == 1:\n",
827
+ " return {\n",
828
+ " 'Target': 'Cyberbully',\n",
829
+ " 'probability': probs[1]\n",
830
+ " }\n",
831
+ " else:\n",
832
+ " return {\n",
833
+ " 'Target': 'Not Cyberbully',\n",
834
+ " 'probability': probs[0]\n",
835
+ " }"
836
+ ]
837
+ },
838
+ {
839
+ "cell_type": "code",
840
+ "execution_count": null,
841
+ "metadata": {
842
+ "id": "NcYq4vmVemi9"
843
+ },
844
+ "outputs": [],
845
+ "source": [
846
+ "# dir()"
847
+ ]
848
+ },
849
+ {
850
+ "cell_type": "code",
851
+ "execution_count": null,
852
+ "metadata": {
853
+ "id": "CS_2FfAeemi_",
854
+ "outputId": "106776a5-fced-4329-aa1f-5970a4a71386"
855
+ },
856
+ "outputs": [
857
+ {
858
+ "name": "stdout",
859
+ "output_type": "stream",
860
+ "text": [
861
+ "Sigmoid()\n"
862
+ ]
863
+ },
864
+ {
865
+ "data": {
866
+ "text/plain": [
867
+ "{'Target': 'Cyberbully', 'probability': 0.9651532}"
868
+ ]
869
+ },
870
+ "execution_count": 24,
871
+ "metadata": {},
872
+ "output_type": "execute_result"
873
+ }
874
+ ],
875
+ "source": [
876
+ "get_prediction('Aku malas kerja dengan orang macam ni menyusahkan orang je')"
877
+ ]
878
+ }
879
+ ],
880
+ "metadata": {
881
+ "kernelspec": {
882
+ "display_name": "Python 3 (ipykernel)",
883
+ "language": "python",
884
+ "name": "python3"
885
+ },
886
+ "language_info": {
887
+ "codemirror_mode": {
888
+ "name": "ipython",
889
+ "version": 3
890
+ },
891
+ "file_extension": ".py",
892
+ "mimetype": "text/x-python",
893
+ "name": "python",
894
+ "nbconvert_exporter": "python",
895
+ "pygments_lexer": "ipython3",
896
+ "version": "3.10.9"
897
+ },
898
+ "vscode": {
899
+ "interpreter": {
900
+ "hash": "173fe52379437b78f95c8980b8ee9f2930fd7b56889ab31a72735475ddc10c81"
901
+ }
902
+ },
903
+ "colab": {
904
+ "provenance": []
905
+ }
906
+ },
907
+ "nbformat": 4,
908
+ "nbformat_minor": 0
909
+ }