prithivMLmods committed
Commit d7f7353 · verified
1 Parent(s): 6f083ab

Delete Builder Script

Builder Script/builder.script.trainner.ipynb DELETED
@@ -1,565 +0,0 @@
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "97b4efc3-1879-4441-af52-de470fbc3ae8",
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -q evaluate datasets accelerate\n",
- "!pip install -q transformers\n",
- "!pip install -q huggingface_hub"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ae923886-86f3-431d-b701-1200110b429c",
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -q imbalanced-learn\n",
- "#Skip the installation if your runtime is in Google Colab notebooks."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "126923c7-d53f-42d8-8f06-2ea05609ab0e",
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -q numpy\n",
- "#Skip the installation if your runtime is in Google Colab notebooks."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9e628805-b90b-4b98-ae97-9f8a8142767f",
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -q pillow==11.0.0\n",
- "#Skip the installation if your runtime is in Google Colab notebooks."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b58fab4c-211f-4b7b-b7c4-dd76e20c1beb",
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -q torchvision \n",
- "#Skip the installation if your runtime is in Google Colab notebooks."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d7454ffa-885e-44ba-8259-d8c45f8ec72b",
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -q matplotlib\n",
- "!pip install -q scikit-learn\n",
- "#Skip the installation if your runtime is in Google Colab notebooks."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4987ed31-c012-434b-9ea7-78da17061d5d",
- "metadata": {},
- "outputs": [],
- "source": [
- "import warnings\n",
- "warnings.filterwarnings(\"ignore\")\n",
- "\n",
- "import gc\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "import itertools\n",
- "from collections import Counter\n",
- "import matplotlib.pyplot as plt\n",
- "from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, f1_score\n",
- "from imblearn.over_sampling import RandomOverSampler\n",
- "import evaluate\n",
- "from datasets import Dataset, Image, ClassLabel\n",
- "from transformers import (\n",
- " TrainingArguments,\n",
- " Trainer,\n",
- " ViTImageProcessor,\n",
- " ViTForImageClassification,\n",
- " DefaultDataCollator\n",
- ")\n",
- "import torch\n",
- "from torch.utils.data import DataLoader\n",
- "from torchvision.transforms import (\n",
- " CenterCrop,\n",
- " Compose,\n",
- " Normalize,\n",
- " RandomRotation,\n",
- " RandomResizedCrop,\n",
- " RandomHorizontalFlip,\n",
- " RandomAdjustSharpness,\n",
- " Resize,\n",
- " ToTensor\n",
- ")\n",
- "\n",
- "#.......................................................................\n",
- "\n",
- "#Retain this part if you're working outside Google Colab notebooks.\n",
- "from PIL import Image, ExifTags\n",
- "\n",
- "#.......................................................................\n",
- "\n",
- "from PIL import Image as PILImage\n",
- "from PIL import ImageFile\n",
- "# Enable loading truncated images\n",
- "ImageFile.LOAD_TRUNCATED_IMAGES = True"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "236bc802-54ba-44d1-b35b-62f548832935",
- "metadata": {},
- "outputs": [],
- "source": [
- "from datasets import load_dataset\n",
- "dataset = load_dataset(\"--your--dataset--goes--here--\", split=\"train\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d57e17cc-72b2-4fde-9855-751cf3440624",
- "metadata": {},
- "outputs": [],
- "source": [
- "from pathlib import Path\n",
- "\n",
- "file_names = []\n",
- "labels = []\n",
- "\n",
- "for example in dataset:\n",
- " file_path = str(example['image']) \n",
- " label = example['label'] \n",
- "\n",
- " file_names.append(file_path) \n",
- " labels.append(label) \n",
- "\n",
- "print(len(file_names), len(labels))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e52c85d2-a245-47c5-9403-5a9cf4e4269d",
- "metadata": {},
- "outputs": [],
- "source": [
- "df = pd.DataFrame.from_dict({\"image\": file_names, \"label\": labels})\n",
- "print(df.shape)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "beba86dd-0605-4ebf-8ebb-97d6ad9e5edd",
- "metadata": {},
- "outputs": [],
- "source": [
- "df.head()\n",
- "df['label'].unique()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6defc1e9-4f46-49b6-addc-f422c38fe7e8",
- "metadata": {},
- "outputs": [],
- "source": [
- "y = df[['label']]\n",
- "df = df.drop(['label'], axis=1)\n",
- "ros = RandomOverSampler(random_state=83)\n",
- "df, y_resampled = ros.fit_resample(df, y)\n",
- "del y\n",
- "df['label'] = y_resampled\n",
- "del y_resampled\n",
- "gc.collect()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "129d278c-3899-49d2-b06f-a0b2f22f4c4e",
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset[0][\"image\"]\n",
- "dataset[99][\"image\"]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bffc8755-c4ac-41be-b8ab-f9a6e0dbcca3",
- "metadata": {},
- "outputs": [],
- "source": [
- "labels_subset = labels[:5]\n",
- "print(labels_subset)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d003f439-09d1-41e6-9f34-213c4ee38593",
- "metadata": {},
- "outputs": [],
- "source": [
- "labels_list = ['Issue In Deepfake', 'High Quality Deepfake']\n",
- "\n",
- "label2id, id2label = {}, {}\n",
- "for i, label in enumerate(labels_list):\n",
- " label2id[label] = i\n",
- " id2label[i] = label\n",
- "\n",
- "ClassLabels = ClassLabel(num_classes=len(labels_list), names=labels_list)\n",
- "\n",
- "print(\"Mapping of IDs to Labels:\", id2label, '\\n')\n",
- "print(\"Mapping of Labels to IDs:\", label2id)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "2fbf1f1b-5936-48be-bc99-6897fea94794",
- "metadata": {},
- "outputs": [],
- "source": [
- "def map_label2id(example):\n",
- " example['label'] = ClassLabels.str2int(example['label'])\n",
- " return example\n",
- "\n",
- "dataset = dataset.map(map_label2id, batched=True)\n",
- "\n",
- "dataset = dataset.cast_column('label', ClassLabels)\n",
- "\n",
- "dataset = dataset.train_test_split(test_size=0.4, shuffle=True, stratify_by_column=\"label\")\n",
- "\n",
- "train_data = dataset['train']\n",
- "\n",
- "test_data = dataset['test']"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d8a4f7ca-4dff-4446-acaf-f3e7630b678d",
- "metadata": {},
- "outputs": [],
- "source": [
- "model_str = \"google/vit-base-patch16-224-in21k\"\n",
- "processor = ViTImageProcessor.from_pretrained(model_str)\n",
- "\n",
- "image_mean, image_std = processor.image_mean, processor.image_std\n",
- "size = processor.size[\"height\"]\n",
- "\n",
- "_train_transforms = Compose(\n",
- " [\n",
- " Resize((size, size)),\n",
- " RandomRotation(90),\n",
- " RandomAdjustSharpness(2),\n",
- " ToTensor(),\n",
- " Normalize(mean=image_mean, std=image_std)\n",
- " ]\n",
- ")\n",
- "\n",
- "_val_transforms = Compose(\n",
- " [\n",
- " Resize((size, size)),\n",
- " ToTensor(),\n",
- " Normalize(mean=image_mean, std=image_std)\n",
- " ]\n",
- ")\n",
- "\n",
- "def train_transforms(examples):\n",
- " examples['pixel_values'] = [_train_transforms(image.convert(\"RGB\")) for image in examples['image']]\n",
- " return examples\n",
- "\n",
- "def val_transforms(examples):\n",
- " examples['pixel_values'] = [_val_transforms(image.convert(\"RGB\")) for image in examples['image']]\n",
- " return examples\n",
- "\n",
- "train_data.set_transform(train_transforms)\n",
- "test_data.set_transform(val_transforms)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0c8a93ca-e4ff-42e2-b58d-445afa0cfee0",
- "metadata": {},
- "outputs": [],
- "source": [
- "def collate_fn(examples):\n",
- " pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n",
- " labels = torch.tensor([example['label'] for example in examples])\n",
- " return {\"pixel_values\": pixel_values, \"labels\": labels}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "11e0c254-ebb1-4100-a389-9e661d0810ff",
- "metadata": {},
- "outputs": [],
- "source": [
- "model = ViTForImageClassification.from_pretrained(model_str, num_labels=len(labels_list))\n",
- "model.config.id2label = id2label\n",
- "model.config.label2id = label2id\n",
- "\n",
- "print(model.num_parameters(only_trainable=True) / 1e6)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bea51959-9abc-4afc-aee6-0e774f8db9c2",
- "metadata": {},
- "outputs": [],
- "source": [
- "accuracy = evaluate.load(\"accuracy\")\n",
- "\n",
- "def compute_metrics(eval_pred):\n",
- " predictions = eval_pred.predictions\n",
- " label_ids = eval_pred.label_ids\n",
- "\n",
- " predicted_labels = predictions.argmax(axis=1)\n",
- " acc_score = accuracy.compute(predictions=predicted_labels, references=label_ids)['accuracy']\n",
- " \n",
- " return {\n",
- " \"accuracy\": acc_score\n",
- " }"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d5ea0bbc-51a3-4b98-823e-10819ffda292",
- "metadata": {},
- "outputs": [],
- "source": [
- "args = TrainingArguments(\n",
- " output_dir=\"deepfake_vit\",\n",
- " logging_dir='./logs',\n",
- " evaluation_strategy=\"epoch\",\n",
- " learning_rate=2e-5,\n",
- " per_device_train_batch_size=32,\n",
- " per_device_eval_batch_size=8,\n",
- " num_train_epochs=4,\n",
- " weight_decay=0.02,\n",
- " warmup_steps=50,\n",
- " remove_unused_columns=False,\n",
- " save_strategy='epoch',\n",
- " load_best_model_at_end=True,\n",
- " save_total_limit=1,\n",
- " report_to=\"none\"\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0a965131-c670-43b1-a153-c1a4df611189",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer = Trainer(\n",
- " model,\n",
- " args,\n",
- " train_dataset=train_data,\n",
- " eval_dataset=test_data,\n",
- " data_collator=collate_fn,\n",
- " compute_metrics=compute_metrics,\n",
- " tokenizer=processor,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ad42ea98-86d6-420e-befe-2ef77eadd76d",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer.evaluate()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "df43c341-0e55-41ef-a274-731c88b9b5d5",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer.train()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "28866dda",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer.evaluate()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0ec258d9",
- "metadata": {},
- "outputs": [],
- "source": [
- "outputs = trainer.predict(test_data)\n",
- "print(outputs.metrics)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c12a6b10",
- "metadata": {},
- "outputs": [],
- "source": [
- "y_true = outputs.label_ids\n",
- "y_pred = outputs.predictions.argmax(1)\n",
- "\n",
- "def plot_confusion_matrix(cm, classes, title='Confusion Matrix', cmap=plt.cm.Blues, figsize=(10, 8)):\n",
- " \n",
- " plt.figure(figsize=figsize)\n",
- "\n",
- " plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
- " plt.title(title)\n",
- " plt.colorbar()\n",
- "\n",
- " tick_marks = np.arange(len(classes))\n",
- " plt.xticks(tick_marks, classes, rotation=90)\n",
- " plt.yticks(tick_marks, classes)\n",
- "\n",
- " fmt = '.0f'\n",
- " thresh = cm.max() / 2.0\n",
- " for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
- " plt.text(j, i, format(cm[i, j], fmt), horizontalalignment=\"center\", color=\"white\" if cm[i, j] > thresh else \"black\")\n",
- "\n",
- " plt.ylabel('True label')\n",
- " plt.xlabel('Predicted label')\n",
- " plt.tight_layout()\n",
- " plt.show()\n",
- "\n",
- "accuracy = accuracy_score(y_true, y_pred)\n",
- "f1 = f1_score(y_true, y_pred, average='macro')\n",
- "\n",
- "print(f\"Accuracy: {accuracy:.4f}\")\n",
- "print(f\"F1 Score: {f1:.4f}\")\n",
- "\n",
- "if len(labels_list) <= 150:\n",
- " cm = confusion_matrix(y_true, y_pred)\n",
- " plot_confusion_matrix(cm, labels_list, figsize=(8, 6))\n",
- "\n",
- "print()\n",
- "print(\"Classification report:\")\n",
- "print()\n",
- "print(classification_report(y_true, y_pred, target_names=labels_list, digits=4))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9889438c",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer.save_model()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "688e3d62",
- "metadata": {},
- "outputs": [],
- "source": [
- "#upload to hub\n",
- "from huggingface_hub import notebook_login\n",
- "notebook_login()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fad56df2",
- "metadata": {},
- "outputs": [],
- "source": [
- "from huggingface_hub import HfApi\n",
- "\n",
- "api = HfApi()\n",
- "repo_id = f\"prithivMLmods/deepfake_vit\"\n",
- "\n",
- "try:\n",
- " api.create_repo(repo_id)\n",
- " print(f\"Repo {repo_id} created\")\n",
- "\n",
- "except:\n",
- " \n",
- " print(f\"Repo {repo_id} already exists\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f5e1559f",
- "metadata": {},
- "outputs": [],
- "source": [
- "api.upload_folder(\n",
- " folder_path=\"deepfake_vit\", \n",
- " path_in_repo=\".\", \n",
- " repo_id=repo_id, \n",
- " repo_type=\"model\", \n",
- " revision=\"main\"\n",
- ")"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.7"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
- }