DeFamy committed
Commit d2cc008 · 1 Parent(s): 0af51d7

Delete train_model.ipynb

Files changed (1)
  1. train_model.ipynb +0 -909
train_model.ipynb DELETED
@@ -1,909 +0,0 @@
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "XLhB2j_Hemio"
- },
- "source": [
- "## Read the dataset CSV file"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "hgYEtrYgemir",
- "outputId": "d3ddedc7-8bd7-4ba9-c82e-68e4eb1309c3"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "<div>\n",
- "<style scoped>\n",
- " .dataframe tbody tr th:only-of-type {\n",
- " vertical-align: middle;\n",
- " }\n",
- "\n",
- " .dataframe tbody tr th {\n",
- " vertical-align: top;\n",
- " }\n",
- "\n",
- " .dataframe thead th {\n",
- " text-align: right;\n",
- " }\n",
- "</style>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- " <thead>\n",
- " <tr style=\"text-align: right;\">\n",
- " <th></th>\n",
- " <th>Unnamed: 0</th>\n",
- " <th>Text</th>\n",
- " <th>target</th>\n",
- " </tr>\n",
- " </thead>\n",
- " <tbody>\n",
- " <tr>\n",
- " <th>0</th>\n",
- " <td>0.0</td>\n",
- " <td>polis tangkap</td>\n",
- " <td>NonCyberbully</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>1</th>\n",
- " <td>1.0</td>\n",
- " <td>kenapa lokasi kebakaran terlalu spesifik</td>\n",
- " <td>NonCyberbully</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>2</th>\n",
- " <td>2.0</td>\n",
- " <td>menyesal tanya nak for birthday</td>\n",
- " <td>NonCyberbully</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>3</th>\n",
- " <td>3.0</td>\n",
- " <td>meriah tah</td>\n",
- " <td>NonCyberbully</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>4</th>\n",
- " <td>4.0</td>\n",
- " <td>asal bs kelar kerja jam sik kl baru diajak mee...</td>\n",
- " <td>NonCyberbully</td>\n",
- " </tr>\n",
- " </tbody>\n",
- "</table>\n",
- "</div>"
- ],
- "text/plain": [
- " Unnamed: 0 Text \\\n",
- "0 0.0 polis tangkap \n",
- "1 1.0 kenapa lokasi kebakaran terlalu spesifik \n",
- "2 2.0 menyesal tanya nak for birthday \n",
- "3 3.0 meriah tah \n",
- "4 4.0 asal bs kelar kerja jam sik kl baru diajak mee... \n",
- "\n",
- " target \n",
- "0 NonCyberbully \n",
- "1 NonCyberbully \n",
- "2 NonCyberbully \n",
- "3 NonCyberbully \n",
- "4 NonCyberbully "
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import pandas as pd\n",
- "df = pd.read_csv('C:/Users/user/Documents/PSM/BERT_Ver2/Transformers-Text-Classification-BERT-Blog-main/input/Tagged_MixedNew.csv')\n",
- "df.head()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "fGUtFkVfemit"
- },
- "source": [
- "## Process the data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "7C3uRWECemiu",
- "outputId": "8e764d84-010d-4e42-987a-af7162627f6e",
- "colab": {
- "referenced_widgets": [
- "042c8b0b8dcf42eb84660c93778d8ea7",
- "4ab6074437a849f79be038b043025283",
- "9aed4d88c18e4e28a1efbbed94331228"
- ]
- }
- },
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "042c8b0b8dcf42eb84660c93778d8ea7",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Downloading (…)okenizer_config.json: 0%| | 0.00/380 [00:00<?, ?B/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "C:\\Users\\user\\anaconda3\\lib\\site-packages\\huggingface_hub\\file_download.py:133: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\user\\.cache\\huggingface\\hub. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
- "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
- " warnings.warn(message)\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "4ab6074437a849f79be038b043025283",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Downloading (…)solve/main/vocab.txt: 0%| | 0.00/233k [00:00<?, ?B/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "9aed4d88c18e4e28a1efbbed94331228",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Downloading (…)cial_tokens_map.json: 0%| | 0.00/125 [00:00<?, ?B/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "#from transformers import BertTokenizer\n",
- "#tokenizer = BertTokenizer.from_pretrained('malay-huggingface/bert-tiny-bahasa-cased')\n",
- "\n",
- "from transformers import AutoTokenizer\n",
- "tokenizer = AutoTokenizer.from_pretrained('mesolitica/bert-base-standard-bahasa-cased')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Ks3XobW0emiu"
- },
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "from sklearn.model_selection import train_test_split\n",
- "from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score\n",
- "import torch\n",
- "from transformers import TrainingArguments, Trainer\n",
- "from transformers import BertTokenizer, BertForSequenceClassification"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "0ZZx6mUdemiv"
- },
- "outputs": [],
- "source": [
- "def process_data(row):\n",
- "\n",
- "    text = row['Text']\n",
- "    text = str(text)\n",
- "    text = ' '.join(text.split())\n",
- "\n",
- "    encodings = tokenizer(text, padding=\"max_length\", truncation=True, max_length=128)\n",
- "\n",
- "    label = 0\n",
- "    if row['target'] == 'Cyberbully':\n",
- "        label += 1\n",
- "\n",
- "    encodings['label'] = label\n",
- "    encodings['Text'] = text\n",
- "\n",
- "    return encodings"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "MaFmqSc-emiv",
- "outputId": "03eb6491-b646-45dd-ef3d-318c81313430"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'input_ids': [2, 2039, 3058, 9857, 1606, 1164, 2161, 8062, 1219, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'label': 0, 'Text': 'Saya suka masakan beliau dan cara penyampaiannya'}\n"
- ]
- }
- ],
- "source": [
- "print(process_data({\n",
- "    'Text': 'Saya suka masakan beliau dan cara penyampaiannya',\n",
- "    'target': 'NonCyberbully'\n",
- "}))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Lel-2lqKemiw"
- },
- "outputs": [],
- "source": [
- "processed_data = []\n",
- "\n",
- "for i in range(len(df[:1383])):\n",
- "    processed_data.append(process_data(df.iloc[i]))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "x_DGsKzHemiw"
- },
- "source": [
- "## Generate the dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "oc_NsbnXemiw"
- },
- "outputs": [],
- "source": [
- "from sklearn.model_selection import train_test_split\n",
- "\n",
- "new_df = pd.DataFrame(processed_data)\n",
- "\n",
- "train_df, valid_df = train_test_split(\n",
- "    new_df,\n",
- "    test_size=0.2,\n",
- "    random_state=2022\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "4qSci5CRemix"
- },
- "outputs": [],
- "source": [
- "import pyarrow as pa\n",
- "from datasets import Dataset\n",
- "\n",
- "train_hg = Dataset(pa.Table.from_pandas(train_df))\n",
- "valid_hg = Dataset(pa.Table.from_pandas(valid_df))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "xDgnim7iemix",
- "outputId": "59858161-59a4-4731-fbfc-7e30a1246eed"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Dataset({\n",
- " features: ['Text', 'attention_mask', 'input_ids', 'label', 'token_type_ids', '__index_level_0__'],\n",
- " num_rows: 277\n",
- "})"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "valid_hg"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "8Uqq0cKKemiy"
- },
- "source": [
- "## Create a model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "QQkDAXmRemiz",
- "outputId": "e00faff0-c7d7-456d-dab2-73d9839c0274",
- "colab": {
- "referenced_widgets": [
- "b9faad28a43547029c8b13ab639f8d05",
- "6175ea4206304020823d86e0bbc23298"
- ]
- }
- },
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "b9faad28a43547029c8b13ab639f8d05",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Downloading (…)lve/main/config.json: 0%| | 0.00/697 [00:00<?, ?B/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "6175ea4206304020823d86e0bbc23298",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Downloading pytorch_model.bin: 0%| | 0.00/443M [00:00<?, ?B/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Some weights of the model checkpoint at mesolitica/bert-base-standard-bahasa-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']\n",
- "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
- "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
- "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at mesolitica/bert-base-standard-bahasa-cased and are newly initialized: ['classifier.bias', 'bert.pooler.dense.bias', 'classifier.weight', 'bert.pooler.dense.weight']\n",
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
- ]
- }
- ],
- "source": [
- "#from transformers import BertForSequenceClassification\n",
- "\n",
- "#model = BertForSequenceClassification.from_pretrained(\n",
- "# 'malay-huggingface/bert-tiny-bahasa-cased',\n",
- "# num_labels=2\n",
- "#)\n",
- "\n",
- "\n",
- "from transformers import AutoModelForSequenceClassification\n",
- "\n",
- "model = AutoModelForSequenceClassification.from_pretrained(\n",
- "    'mesolitica/bert-base-standard-bahasa-cased',\n",
- "    num_labels=2\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "ifvtnwBMemi1"
- },
- "outputs": [],
- "source": [
- "def compute_metrics(p):\n",
- "    print(type(p))\n",
- "    pred, labels = p\n",
- "    pred = np.argmax(pred, axis=1)\n",
- "\n",
- "    accuracy = accuracy_score(y_true=labels, y_pred=pred)\n",
- "    recall = recall_score(y_true=labels, y_pred=pred)\n",
- "    precision = precision_score(y_true=labels, y_pred=pred)\n",
- "    f1 = f1_score(y_true=labels, y_pred=pred)\n",
- "\n",
- "    return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1}\n",
- ""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "50Xy9P7Remi2"
- },
- "outputs": [],
- "source": [
- "from transformers import TrainingArguments, Trainer\n",
- "\n",
- "training_args = TrainingArguments(output_dir=\"./result\", evaluation_strategy=\"epoch\")\n",
- "\n",
- "trainer = Trainer(\n",
- "    model=model,\n",
- "    args=training_args,\n",
- "    train_dataset=train_hg,\n",
- "    eval_dataset=valid_hg,\n",
- "    tokenizer=tokenizer,\n",
- "    compute_metrics=compute_metrics\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "myIstfgJemi3"
- },
- "source": [
- "## Train and Evaluate the model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "-UtAkNHUemi4",
- "outputId": "5af038f3-a77c-41eb-e48d-747a8e776e38"
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "C:\\Users\\user\\anaconda3\\lib\\site-packages\\transformers\\optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
- " warnings.warn(\n",
- "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- " <div>\n",
- " \n",
- " <progress value='417' max='417' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
- " [417/417 56:36, Epoch 3/3]\n",
- " </div>\n",
- " <table border=\"1\" class=\"dataframe\">\n",
- " <thead>\n",
- " <tr style=\"text-align: left;\">\n",
- " <th>Epoch</th>\n",
- " <th>Training Loss</th>\n",
- " <th>Validation Loss</th>\n",
- " <th>Accuracy</th>\n",
- " <th>Precision</th>\n",
- " <th>Recall</th>\n",
- " <th>F1</th>\n",
- " </tr>\n",
- " </thead>\n",
- " <tbody>\n",
- " <tr>\n",
- " <td>1</td>\n",
- " <td>No log</td>\n",
- " <td>0.493876</td>\n",
- " <td>0.779783</td>\n",
- " <td>0.657343</td>\n",
- " <td>0.886792</td>\n",
- " <td>0.755020</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>2</td>\n",
- " <td>No log</td>\n",
- " <td>0.542367</td>\n",
- " <td>0.870036</td>\n",
- " <td>0.850000</td>\n",
- " <td>0.801887</td>\n",
- " <td>0.825243</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>3</td>\n",
- " <td>No log</td>\n",
- " <td>0.725669</td>\n",
- " <td>0.848375</td>\n",
- " <td>0.820000</td>\n",
- " <td>0.773585</td>\n",
- " <td>0.796117</td>\n",
- " </tr>\n",
- " </tbody>\n",
- "</table><p>"
- ],
- "text/plain": [
- "<IPython.core.display.HTML object>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "<class 'transformers.trainer_utils.EvalPrediction'>\n",
- "<class 'transformers.trainer_utils.EvalPrediction'>\n",
- "<class 'transformers.trainer_utils.EvalPrediction'>\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "TrainOutput(global_step=417, training_loss=0.2771467213436282, metrics={'train_runtime': 3405.0836, 'train_samples_per_second': 0.974, 'train_steps_per_second': 0.122, 'total_flos': 218053287129600.0, 'train_loss': 0.2771467213436282, 'epoch': 3.0})"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "trainer.train()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "fZYGhNyremi4",
- "outputId": "5119c379-d7e9-48f7-9137-d788f99a3731"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " <div>\n",
- " \n",
- " <progress value='35' max='35' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
- " [35/35 00:43]\n",
- " </div>\n",
- " "
- ],
- "text/plain": [
- "<IPython.core.display.HTML object>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "<class 'transformers.trainer_utils.EvalPrediction'>\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'eval_loss': 0.7256694436073303,\n",
- " 'eval_accuracy': 0.8483754512635379,\n",
- " 'eval_precision': 0.82,\n",
- " 'eval_recall': 0.7735849056603774,\n",
- " 'eval_f1': 0.796116504854369,\n",
- " 'eval_runtime': 44.9419,\n",
- " 'eval_samples_per_second': 6.164,\n",
- " 'eval_steps_per_second': 0.779,\n",
- " 'epoch': 3.0}"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "trainer.evaluate()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "tlw24Ccdemi5"
- },
- "source": [
- "## Save the model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "69n4eVBHemi6"
- },
- "outputs": [],
- "source": [
- "model.save_pretrained('./model/')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "gC9qDoERemi6",
- "outputId": "a5514df7-d322-48b9-df27-c799dca6d884"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Looking in indexes: https://download.pytorch.org/whl/cu117\n",
- "Requirement already satisfied: torch in c:\\users\\user\\anaconda3\\lib\\site-packages (2.0.1+cu118)\n",
- "Requirement already satisfied: torchvision in c:\\users\\user\\anaconda3\\lib\\site-packages (0.15.2+cu117)\n",
- "Requirement already satisfied: torchaudio in c:\\users\\user\\anaconda3\\lib\\site-packages (2.0.2+cu117)\n",
- "Requirement already satisfied: sympy in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (1.11.1)\n",
- "Requirement already satisfied: jinja2 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (3.1.2)\n",
- "Requirement already satisfied: filelock in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (3.9.0)\n",
- "Requirement already satisfied: networkx in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (2.5.1)\n",
- "Requirement already satisfied: typing-extensions in c:\\users\\user\\anaconda3\\lib\\site-packages (from torch) (4.4.0)\n",
- "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (9.4.0)\n",
- "Requirement already satisfied: numpy in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (1.23.5)\n",
- "Requirement already satisfied: requests in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (2.28.1)\n",
- "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from jinja2->torch) (2.1.1)\n",
- "Requirement already satisfied: decorator<5,>=4.3 in c:\\users\\user\\anaconda3\\lib\\site-packages (from networkx->torch) (4.4.2)\n",
- "Requirement already satisfied: charset-normalizer<3,>=2 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.0.4)\n",
- "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (1.26.14)\n",
- "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.10)\n",
- "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2022.12.7)\n",
- "Requirement already satisfied: mpmath>=0.19 in c:\\users\\user\\anaconda3\\lib\\site-packages (from sympy->torch) (1.2.1)\n"
- ]
- }
- ],
- "source": [
- "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "3NBugUKAemi7"
- },
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "-W3_K_Kjemi7"
- },
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "yMiT54Ddemi7"
- },
- "source": [
- "## Load the model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "mEFnUaM3emi7"
- },
- "outputs": [],
- "source": [
- "import torch\n",
- "from transformers import AutoModelForSequenceClassification\n",
- "\n",
- "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
- "\n",
- "new_model = AutoModelForSequenceClassification.from_pretrained('./model/').to(device)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "zkDeulcTemi8",
- "outputId": "2500b324-398b-471b-9c08-48fa79ea9de3"
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "ERROR: torch-1.0.1-cp36-cp36m-win_amd64.whl is not a supported wheel on this platform.\n",
- "\n",
- "[notice] A new release of pip is available: 23.0.1 -> 23.1.2\n",
- "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Requirement already satisfied: torchvision in c:\\users\\user\\anaconda3\\lib\\site-packages (0.14.0)\n",
- "Requirement already satisfied: typing-extensions in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (4.1.1)\n",
- "Requirement already satisfied: requests in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (2.27.1)\n",
- "Requirement already satisfied: torch==1.13.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (1.13.0)\n",
- "Requirement already satisfied: numpy in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (1.24.2)\n",
- "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from torchvision) (9.0.1)\n",
- "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (3.3)\n",
- "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2.0.4)\n",
- "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (2022.9.24)\n",
- "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->torchvision) (1.26.9)\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\n",
- "[notice] A new release of pip is available: 23.0.1 -> 23.1.2\n",
- "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
- ]
- }
- ],
- "source": [
- "!pip install https://download.pytorch.org/whl/cpu/torch-1.0.1-cp36-cp36m-win_amd64.whl\n",
- "!pip install torchvision"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "WtI-WDBhemi8"
- },
- "outputs": [],
- "source": [
- "from transformers import AutoTokenizer\n",
- "\n",
- "new_tokenizer = AutoTokenizer.from_pretrained('mesolitica/bert-base-standard-bahasa-cased')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "S2X_uPYJemi9"
- },
- "source": [
- "## Get predictions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "qXKQEiWxemi9"
- },
- "outputs": [],
- "source": [
- "import torch\n",
- "import numpy as np\n",
- "\n",
- "def get_prediction(text):\n",
- "    encoding = new_tokenizer(text, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=128)\n",
- "    encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}\n",
- "\n",
- "    outputs = new_model(**encoding)\n",
- "\n",
- "    logits = outputs.logits\n",
- "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
- "    sigmoid = torch.nn.Sigmoid()\n",
- "    print(sigmoid)\n",
- "    probs = sigmoid(logits.squeeze().cpu())\n",
- "    probs = probs.detach().numpy()\n",
- "    label = np.argmax(probs, axis=-1)\n",
- "\n",
- "    if label == 1:\n",
- "        return {\n",
- "            'Target': 'Cyberbully',\n",
- "            'probability': probs[1]\n",
- "        }\n",
- "    else:\n",
- "        return {\n",
- "            'Target': 'Not Cyberbully',\n",
- "            'probability': probs[0]\n",
- "        }"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "NcYq4vmVemi9"
- },
- "outputs": [],
- "source": [
- "# dir()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "CS_2FfAeemi_",
- "outputId": "106776a5-fced-4329-aa1f-5970a4a71386"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Sigmoid()\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "{'Target': 'Cyberbully', 'probability': 0.9651532}"
- ]
- },
- "execution_count": 24,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "get_prediction('Aku malas kerja dengan orang macam ni menyusahkan orang je')"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.9"
- },
- "vscode": {
- "interpreter": {
- "hash": "173fe52379437b78f95c8980b8ee9f2930fd7b56889ab31a72735475ddc10c81"
- }
- },
- "colab": {
- "provenance": []
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
- }