Arunisto commited on
Commit
fbf424b
·
verified ·
1 Parent(s): adcde38

Upload UsingApretrainedSwinTransformerModelForImageClassification.ipynb

Browse files
UsingApretrainedSwinTransformerModelForImageClassification.ipynb ADDED
@@ -0,0 +1,1425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ },
15
+ "widgets": {
16
+ "application/vnd.jupyter.widget-state+json": {
17
+ "dc7ab4859abf4fdaa598729ec421129a": {
18
+ "model_module": "@jupyter-widgets/controls",
19
+ "model_name": "HBoxModel",
20
+ "model_module_version": "1.5.0",
21
+ "state": {
22
+ "_dom_classes": [],
23
+ "_model_module": "@jupyter-widgets/controls",
24
+ "_model_module_version": "1.5.0",
25
+ "_model_name": "HBoxModel",
26
+ "_view_count": null,
27
+ "_view_module": "@jupyter-widgets/controls",
28
+ "_view_module_version": "1.5.0",
29
+ "_view_name": "HBoxView",
30
+ "box_style": "",
31
+ "children": [
32
+ "IPY_MODEL_5441781ef8f046dd979a8c389a8e7587",
33
+ "IPY_MODEL_999c37c9cce448a792df58915545135f",
34
+ "IPY_MODEL_db66afed001a4324bdd44f51efeaafa1"
35
+ ],
36
+ "layout": "IPY_MODEL_11cb47a0455044af83969ae1994af22f"
37
+ }
38
+ },
39
+ "5441781ef8f046dd979a8c389a8e7587": {
40
+ "model_module": "@jupyter-widgets/controls",
41
+ "model_name": "HTMLModel",
42
+ "model_module_version": "1.5.0",
43
+ "state": {
44
+ "_dom_classes": [],
45
+ "_model_module": "@jupyter-widgets/controls",
46
+ "_model_module_version": "1.5.0",
47
+ "_model_name": "HTMLModel",
48
+ "_view_count": null,
49
+ "_view_module": "@jupyter-widgets/controls",
50
+ "_view_module_version": "1.5.0",
51
+ "_view_name": "HTMLView",
52
+ "description": "",
53
+ "description_tooltip": null,
54
+ "layout": "IPY_MODEL_f884e42910e14b5aa9dc9919558cf75f",
55
+ "placeholder": "​",
56
+ "style": "IPY_MODEL_53cc94f964d54d71923bc295d2882216",
57
+ "value": "config.json: 100%"
58
+ }
59
+ },
60
+ "999c37c9cce448a792df58915545135f": {
61
+ "model_module": "@jupyter-widgets/controls",
62
+ "model_name": "FloatProgressModel",
63
+ "model_module_version": "1.5.0",
64
+ "state": {
65
+ "_dom_classes": [],
66
+ "_model_module": "@jupyter-widgets/controls",
67
+ "_model_module_version": "1.5.0",
68
+ "_model_name": "FloatProgressModel",
69
+ "_view_count": null,
70
+ "_view_module": "@jupyter-widgets/controls",
71
+ "_view_module_version": "1.5.0",
72
+ "_view_name": "ProgressView",
73
+ "bar_style": "success",
74
+ "description": "",
75
+ "description_tooltip": null,
76
+ "layout": "IPY_MODEL_b0a3d7e1d2ea4e3bb7e3264c8089b26d",
77
+ "max": 71813,
78
+ "min": 0,
79
+ "orientation": "horizontal",
80
+ "style": "IPY_MODEL_279f9349c7334713bb760bbc4bff3d3c",
81
+ "value": 71813
82
+ }
83
+ },
84
+ "db66afed001a4324bdd44f51efeaafa1": {
85
+ "model_module": "@jupyter-widgets/controls",
86
+ "model_name": "HTMLModel",
87
+ "model_module_version": "1.5.0",
88
+ "state": {
89
+ "_dom_classes": [],
90
+ "_model_module": "@jupyter-widgets/controls",
91
+ "_model_module_version": "1.5.0",
92
+ "_model_name": "HTMLModel",
93
+ "_view_count": null,
94
+ "_view_module": "@jupyter-widgets/controls",
95
+ "_view_module_version": "1.5.0",
96
+ "_view_name": "HTMLView",
97
+ "description": "",
98
+ "description_tooltip": null,
99
+ "layout": "IPY_MODEL_865cd86f5627433e9d92c1de46917f08",
100
+ "placeholder": "​",
101
+ "style": "IPY_MODEL_73db46b258794cfe922424d06689e44e",
102
+ "value": " 71.8k/71.8k [00:00<00:00, 3.22MB/s]"
103
+ }
104
+ },
105
+ "11cb47a0455044af83969ae1994af22f": {
106
+ "model_module": "@jupyter-widgets/base",
107
+ "model_name": "LayoutModel",
108
+ "model_module_version": "1.2.0",
109
+ "state": {
110
+ "_model_module": "@jupyter-widgets/base",
111
+ "_model_module_version": "1.2.0",
112
+ "_model_name": "LayoutModel",
113
+ "_view_count": null,
114
+ "_view_module": "@jupyter-widgets/base",
115
+ "_view_module_version": "1.2.0",
116
+ "_view_name": "LayoutView",
117
+ "align_content": null,
118
+ "align_items": null,
119
+ "align_self": null,
120
+ "border": null,
121
+ "bottom": null,
122
+ "display": null,
123
+ "flex": null,
124
+ "flex_flow": null,
125
+ "grid_area": null,
126
+ "grid_auto_columns": null,
127
+ "grid_auto_flow": null,
128
+ "grid_auto_rows": null,
129
+ "grid_column": null,
130
+ "grid_gap": null,
131
+ "grid_row": null,
132
+ "grid_template_areas": null,
133
+ "grid_template_columns": null,
134
+ "grid_template_rows": null,
135
+ "height": null,
136
+ "justify_content": null,
137
+ "justify_items": null,
138
+ "left": null,
139
+ "margin": null,
140
+ "max_height": null,
141
+ "max_width": null,
142
+ "min_height": null,
143
+ "min_width": null,
144
+ "object_fit": null,
145
+ "object_position": null,
146
+ "order": null,
147
+ "overflow": null,
148
+ "overflow_x": null,
149
+ "overflow_y": null,
150
+ "padding": null,
151
+ "right": null,
152
+ "top": null,
153
+ "visibility": null,
154
+ "width": null
155
+ }
156
+ },
157
+ "f884e42910e14b5aa9dc9919558cf75f": {
158
+ "model_module": "@jupyter-widgets/base",
159
+ "model_name": "LayoutModel",
160
+ "model_module_version": "1.2.0",
161
+ "state": {
162
+ "_model_module": "@jupyter-widgets/base",
163
+ "_model_module_version": "1.2.0",
164
+ "_model_name": "LayoutModel",
165
+ "_view_count": null,
166
+ "_view_module": "@jupyter-widgets/base",
167
+ "_view_module_version": "1.2.0",
168
+ "_view_name": "LayoutView",
169
+ "align_content": null,
170
+ "align_items": null,
171
+ "align_self": null,
172
+ "border": null,
173
+ "bottom": null,
174
+ "display": null,
175
+ "flex": null,
176
+ "flex_flow": null,
177
+ "grid_area": null,
178
+ "grid_auto_columns": null,
179
+ "grid_auto_flow": null,
180
+ "grid_auto_rows": null,
181
+ "grid_column": null,
182
+ "grid_gap": null,
183
+ "grid_row": null,
184
+ "grid_template_areas": null,
185
+ "grid_template_columns": null,
186
+ "grid_template_rows": null,
187
+ "height": null,
188
+ "justify_content": null,
189
+ "justify_items": null,
190
+ "left": null,
191
+ "margin": null,
192
+ "max_height": null,
193
+ "max_width": null,
194
+ "min_height": null,
195
+ "min_width": null,
196
+ "object_fit": null,
197
+ "object_position": null,
198
+ "order": null,
199
+ "overflow": null,
200
+ "overflow_x": null,
201
+ "overflow_y": null,
202
+ "padding": null,
203
+ "right": null,
204
+ "top": null,
205
+ "visibility": null,
206
+ "width": null
207
+ }
208
+ },
209
+ "53cc94f964d54d71923bc295d2882216": {
210
+ "model_module": "@jupyter-widgets/controls",
211
+ "model_name": "DescriptionStyleModel",
212
+ "model_module_version": "1.5.0",
213
+ "state": {
214
+ "_model_module": "@jupyter-widgets/controls",
215
+ "_model_module_version": "1.5.0",
216
+ "_model_name": "DescriptionStyleModel",
217
+ "_view_count": null,
218
+ "_view_module": "@jupyter-widgets/base",
219
+ "_view_module_version": "1.2.0",
220
+ "_view_name": "StyleView",
221
+ "description_width": ""
222
+ }
223
+ },
224
+ "b0a3d7e1d2ea4e3bb7e3264c8089b26d": {
225
+ "model_module": "@jupyter-widgets/base",
226
+ "model_name": "LayoutModel",
227
+ "model_module_version": "1.2.0",
228
+ "state": {
229
+ "_model_module": "@jupyter-widgets/base",
230
+ "_model_module_version": "1.2.0",
231
+ "_model_name": "LayoutModel",
232
+ "_view_count": null,
233
+ "_view_module": "@jupyter-widgets/base",
234
+ "_view_module_version": "1.2.0",
235
+ "_view_name": "LayoutView",
236
+ "align_content": null,
237
+ "align_items": null,
238
+ "align_self": null,
239
+ "border": null,
240
+ "bottom": null,
241
+ "display": null,
242
+ "flex": null,
243
+ "flex_flow": null,
244
+ "grid_area": null,
245
+ "grid_auto_columns": null,
246
+ "grid_auto_flow": null,
247
+ "grid_auto_rows": null,
248
+ "grid_column": null,
249
+ "grid_gap": null,
250
+ "grid_row": null,
251
+ "grid_template_areas": null,
252
+ "grid_template_columns": null,
253
+ "grid_template_rows": null,
254
+ "height": null,
255
+ "justify_content": null,
256
+ "justify_items": null,
257
+ "left": null,
258
+ "margin": null,
259
+ "max_height": null,
260
+ "max_width": null,
261
+ "min_height": null,
262
+ "min_width": null,
263
+ "object_fit": null,
264
+ "object_position": null,
265
+ "order": null,
266
+ "overflow": null,
267
+ "overflow_x": null,
268
+ "overflow_y": null,
269
+ "padding": null,
270
+ "right": null,
271
+ "top": null,
272
+ "visibility": null,
273
+ "width": null
274
+ }
275
+ },
276
+ "279f9349c7334713bb760bbc4bff3d3c": {
277
+ "model_module": "@jupyter-widgets/controls",
278
+ "model_name": "ProgressStyleModel",
279
+ "model_module_version": "1.5.0",
280
+ "state": {
281
+ "_model_module": "@jupyter-widgets/controls",
282
+ "_model_module_version": "1.5.0",
283
+ "_model_name": "ProgressStyleModel",
284
+ "_view_count": null,
285
+ "_view_module": "@jupyter-widgets/base",
286
+ "_view_module_version": "1.2.0",
287
+ "_view_name": "StyleView",
288
+ "bar_color": null,
289
+ "description_width": ""
290
+ }
291
+ },
292
+ "865cd86f5627433e9d92c1de46917f08": {
293
+ "model_module": "@jupyter-widgets/base",
294
+ "model_name": "LayoutModel",
295
+ "model_module_version": "1.2.0",
296
+ "state": {
297
+ "_model_module": "@jupyter-widgets/base",
298
+ "_model_module_version": "1.2.0",
299
+ "_model_name": "LayoutModel",
300
+ "_view_count": null,
301
+ "_view_module": "@jupyter-widgets/base",
302
+ "_view_module_version": "1.2.0",
303
+ "_view_name": "LayoutView",
304
+ "align_content": null,
305
+ "align_items": null,
306
+ "align_self": null,
307
+ "border": null,
308
+ "bottom": null,
309
+ "display": null,
310
+ "flex": null,
311
+ "flex_flow": null,
312
+ "grid_area": null,
313
+ "grid_auto_columns": null,
314
+ "grid_auto_flow": null,
315
+ "grid_auto_rows": null,
316
+ "grid_column": null,
317
+ "grid_gap": null,
318
+ "grid_row": null,
319
+ "grid_template_areas": null,
320
+ "grid_template_columns": null,
321
+ "grid_template_rows": null,
322
+ "height": null,
323
+ "justify_content": null,
324
+ "justify_items": null,
325
+ "left": null,
326
+ "margin": null,
327
+ "max_height": null,
328
+ "max_width": null,
329
+ "min_height": null,
330
+ "min_width": null,
331
+ "object_fit": null,
332
+ "object_position": null,
333
+ "order": null,
334
+ "overflow": null,
335
+ "overflow_x": null,
336
+ "overflow_y": null,
337
+ "padding": null,
338
+ "right": null,
339
+ "top": null,
340
+ "visibility": null,
341
+ "width": null
342
+ }
343
+ },
344
+ "73db46b258794cfe922424d06689e44e": {
345
+ "model_module": "@jupyter-widgets/controls",
346
+ "model_name": "DescriptionStyleModel",
347
+ "model_module_version": "1.5.0",
348
+ "state": {
349
+ "_model_module": "@jupyter-widgets/controls",
350
+ "_model_module_version": "1.5.0",
351
+ "_model_name": "DescriptionStyleModel",
352
+ "_view_count": null,
353
+ "_view_module": "@jupyter-widgets/base",
354
+ "_view_module_version": "1.2.0",
355
+ "_view_name": "StyleView",
356
+ "description_width": ""
357
+ }
358
+ },
359
+ "cd77cd0c2d1741c39b9ca6f36ec562c3": {
360
+ "model_module": "@jupyter-widgets/controls",
361
+ "model_name": "HBoxModel",
362
+ "model_module_version": "1.5.0",
363
+ "state": {
364
+ "_dom_classes": [],
365
+ "_model_module": "@jupyter-widgets/controls",
366
+ "_model_module_version": "1.5.0",
367
+ "_model_name": "HBoxModel",
368
+ "_view_count": null,
369
+ "_view_module": "@jupyter-widgets/controls",
370
+ "_view_module_version": "1.5.0",
371
+ "_view_name": "HBoxView",
372
+ "box_style": "",
373
+ "children": [
374
+ "IPY_MODEL_4f28a143f3cd438b87a74a62d9c5bea4",
375
+ "IPY_MODEL_256dc2937cb2449f84ec14b9dae2e51f",
376
+ "IPY_MODEL_a60bc29c8c1044f690a6cd4f74ed7c2c"
377
+ ],
378
+ "layout": "IPY_MODEL_a33d503834a741a9a23a26c03c1c3aaf"
379
+ }
380
+ },
381
+ "4f28a143f3cd438b87a74a62d9c5bea4": {
382
+ "model_module": "@jupyter-widgets/controls",
383
+ "model_name": "HTMLModel",
384
+ "model_module_version": "1.5.0",
385
+ "state": {
386
+ "_dom_classes": [],
387
+ "_model_module": "@jupyter-widgets/controls",
388
+ "_model_module_version": "1.5.0",
389
+ "_model_name": "HTMLModel",
390
+ "_view_count": null,
391
+ "_view_module": "@jupyter-widgets/controls",
392
+ "_view_module_version": "1.5.0",
393
+ "_view_name": "HTMLView",
394
+ "description": "",
395
+ "description_tooltip": null,
396
+ "layout": "IPY_MODEL_a2f0b441f346498faabf97d5d8eabd3a",
397
+ "placeholder": "​",
398
+ "style": "IPY_MODEL_3243191638d14067a1565ae2f03f48de",
399
+ "value": "model.safetensors: 100%"
400
+ }
401
+ },
402
+ "256dc2937cb2449f84ec14b9dae2e51f": {
403
+ "model_module": "@jupyter-widgets/controls",
404
+ "model_name": "FloatProgressModel",
405
+ "model_module_version": "1.5.0",
406
+ "state": {
407
+ "_dom_classes": [],
408
+ "_model_module": "@jupyter-widgets/controls",
409
+ "_model_module_version": "1.5.0",
410
+ "_model_name": "FloatProgressModel",
411
+ "_view_count": null,
412
+ "_view_module": "@jupyter-widgets/controls",
413
+ "_view_module_version": "1.5.0",
414
+ "_view_name": "ProgressView",
415
+ "bar_style": "success",
416
+ "description": "",
417
+ "description_tooltip": null,
418
+ "layout": "IPY_MODEL_f0b74718119e4f3a9bb6295bc7567f6d",
419
+ "max": 113412768,
420
+ "min": 0,
421
+ "orientation": "horizontal",
422
+ "style": "IPY_MODEL_fd6d8576cbc444e1b22111ddf4fa1dc2",
423
+ "value": 113412768
424
+ }
425
+ },
426
+ "a60bc29c8c1044f690a6cd4f74ed7c2c": {
427
+ "model_module": "@jupyter-widgets/controls",
428
+ "model_name": "HTMLModel",
429
+ "model_module_version": "1.5.0",
430
+ "state": {
431
+ "_dom_classes": [],
432
+ "_model_module": "@jupyter-widgets/controls",
433
+ "_model_module_version": "1.5.0",
434
+ "_model_name": "HTMLModel",
435
+ "_view_count": null,
436
+ "_view_module": "@jupyter-widgets/controls",
437
+ "_view_module_version": "1.5.0",
438
+ "_view_name": "HTMLView",
439
+ "description": "",
440
+ "description_tooltip": null,
441
+ "layout": "IPY_MODEL_c469ed61c6b946b78bb0308f36cb915b",
442
+ "placeholder": "​",
443
+ "style": "IPY_MODEL_e0db11f2d42744e982e2206a1640344a",
444
+ "value": " 113M/113M [00:01<00:00, 130MB/s]"
445
+ }
446
+ },
447
+ "a33d503834a741a9a23a26c03c1c3aaf": {
448
+ "model_module": "@jupyter-widgets/base",
449
+ "model_name": "LayoutModel",
450
+ "model_module_version": "1.2.0",
451
+ "state": {
452
+ "_model_module": "@jupyter-widgets/base",
453
+ "_model_module_version": "1.2.0",
454
+ "_model_name": "LayoutModel",
455
+ "_view_count": null,
456
+ "_view_module": "@jupyter-widgets/base",
457
+ "_view_module_version": "1.2.0",
458
+ "_view_name": "LayoutView",
459
+ "align_content": null,
460
+ "align_items": null,
461
+ "align_self": null,
462
+ "border": null,
463
+ "bottom": null,
464
+ "display": null,
465
+ "flex": null,
466
+ "flex_flow": null,
467
+ "grid_area": null,
468
+ "grid_auto_columns": null,
469
+ "grid_auto_flow": null,
470
+ "grid_auto_rows": null,
471
+ "grid_column": null,
472
+ "grid_gap": null,
473
+ "grid_row": null,
474
+ "grid_template_areas": null,
475
+ "grid_template_columns": null,
476
+ "grid_template_rows": null,
477
+ "height": null,
478
+ "justify_content": null,
479
+ "justify_items": null,
480
+ "left": null,
481
+ "margin": null,
482
+ "max_height": null,
483
+ "max_width": null,
484
+ "min_height": null,
485
+ "min_width": null,
486
+ "object_fit": null,
487
+ "object_position": null,
488
+ "order": null,
489
+ "overflow": null,
490
+ "overflow_x": null,
491
+ "overflow_y": null,
492
+ "padding": null,
493
+ "right": null,
494
+ "top": null,
495
+ "visibility": null,
496
+ "width": null
497
+ }
498
+ },
499
+ "a2f0b441f346498faabf97d5d8eabd3a": {
500
+ "model_module": "@jupyter-widgets/base",
501
+ "model_name": "LayoutModel",
502
+ "model_module_version": "1.2.0",
503
+ "state": {
504
+ "_model_module": "@jupyter-widgets/base",
505
+ "_model_module_version": "1.2.0",
506
+ "_model_name": "LayoutModel",
507
+ "_view_count": null,
508
+ "_view_module": "@jupyter-widgets/base",
509
+ "_view_module_version": "1.2.0",
510
+ "_view_name": "LayoutView",
511
+ "align_content": null,
512
+ "align_items": null,
513
+ "align_self": null,
514
+ "border": null,
515
+ "bottom": null,
516
+ "display": null,
517
+ "flex": null,
518
+ "flex_flow": null,
519
+ "grid_area": null,
520
+ "grid_auto_columns": null,
521
+ "grid_auto_flow": null,
522
+ "grid_auto_rows": null,
523
+ "grid_column": null,
524
+ "grid_gap": null,
525
+ "grid_row": null,
526
+ "grid_template_areas": null,
527
+ "grid_template_columns": null,
528
+ "grid_template_rows": null,
529
+ "height": null,
530
+ "justify_content": null,
531
+ "justify_items": null,
532
+ "left": null,
533
+ "margin": null,
534
+ "max_height": null,
535
+ "max_width": null,
536
+ "min_height": null,
537
+ "min_width": null,
538
+ "object_fit": null,
539
+ "object_position": null,
540
+ "order": null,
541
+ "overflow": null,
542
+ "overflow_x": null,
543
+ "overflow_y": null,
544
+ "padding": null,
545
+ "right": null,
546
+ "top": null,
547
+ "visibility": null,
548
+ "width": null
549
+ }
550
+ },
551
+ "3243191638d14067a1565ae2f03f48de": {
552
+ "model_module": "@jupyter-widgets/controls",
553
+ "model_name": "DescriptionStyleModel",
554
+ "model_module_version": "1.5.0",
555
+ "state": {
556
+ "_model_module": "@jupyter-widgets/controls",
557
+ "_model_module_version": "1.5.0",
558
+ "_model_name": "DescriptionStyleModel",
559
+ "_view_count": null,
560
+ "_view_module": "@jupyter-widgets/base",
561
+ "_view_module_version": "1.2.0",
562
+ "_view_name": "StyleView",
563
+ "description_width": ""
564
+ }
565
+ },
566
+ "f0b74718119e4f3a9bb6295bc7567f6d": {
567
+ "model_module": "@jupyter-widgets/base",
568
+ "model_name": "LayoutModel",
569
+ "model_module_version": "1.2.0",
570
+ "state": {
571
+ "_model_module": "@jupyter-widgets/base",
572
+ "_model_module_version": "1.2.0",
573
+ "_model_name": "LayoutModel",
574
+ "_view_count": null,
575
+ "_view_module": "@jupyter-widgets/base",
576
+ "_view_module_version": "1.2.0",
577
+ "_view_name": "LayoutView",
578
+ "align_content": null,
579
+ "align_items": null,
580
+ "align_self": null,
581
+ "border": null,
582
+ "bottom": null,
583
+ "display": null,
584
+ "flex": null,
585
+ "flex_flow": null,
586
+ "grid_area": null,
587
+ "grid_auto_columns": null,
588
+ "grid_auto_flow": null,
589
+ "grid_auto_rows": null,
590
+ "grid_column": null,
591
+ "grid_gap": null,
592
+ "grid_row": null,
593
+ "grid_template_areas": null,
594
+ "grid_template_columns": null,
595
+ "grid_template_rows": null,
596
+ "height": null,
597
+ "justify_content": null,
598
+ "justify_items": null,
599
+ "left": null,
600
+ "margin": null,
601
+ "max_height": null,
602
+ "max_width": null,
603
+ "min_height": null,
604
+ "min_width": null,
605
+ "object_fit": null,
606
+ "object_position": null,
607
+ "order": null,
608
+ "overflow": null,
609
+ "overflow_x": null,
610
+ "overflow_y": null,
611
+ "padding": null,
612
+ "right": null,
613
+ "top": null,
614
+ "visibility": null,
615
+ "width": null
616
+ }
617
+ },
618
+ "fd6d8576cbc444e1b22111ddf4fa1dc2": {
619
+ "model_module": "@jupyter-widgets/controls",
620
+ "model_name": "ProgressStyleModel",
621
+ "model_module_version": "1.5.0",
622
+ "state": {
623
+ "_model_module": "@jupyter-widgets/controls",
624
+ "_model_module_version": "1.5.0",
625
+ "_model_name": "ProgressStyleModel",
626
+ "_view_count": null,
627
+ "_view_module": "@jupyter-widgets/base",
628
+ "_view_module_version": "1.2.0",
629
+ "_view_name": "StyleView",
630
+ "bar_color": null,
631
+ "description_width": ""
632
+ }
633
+ },
634
+ "c469ed61c6b946b78bb0308f36cb915b": {
635
+ "model_module": "@jupyter-widgets/base",
636
+ "model_name": "LayoutModel",
637
+ "model_module_version": "1.2.0",
638
+ "state": {
639
+ "_model_module": "@jupyter-widgets/base",
640
+ "_model_module_version": "1.2.0",
641
+ "_model_name": "LayoutModel",
642
+ "_view_count": null,
643
+ "_view_module": "@jupyter-widgets/base",
644
+ "_view_module_version": "1.2.0",
645
+ "_view_name": "LayoutView",
646
+ "align_content": null,
647
+ "align_items": null,
648
+ "align_self": null,
649
+ "border": null,
650
+ "bottom": null,
651
+ "display": null,
652
+ "flex": null,
653
+ "flex_flow": null,
654
+ "grid_area": null,
655
+ "grid_auto_columns": null,
656
+ "grid_auto_flow": null,
657
+ "grid_auto_rows": null,
658
+ "grid_column": null,
659
+ "grid_gap": null,
660
+ "grid_row": null,
661
+ "grid_template_areas": null,
662
+ "grid_template_columns": null,
663
+ "grid_template_rows": null,
664
+ "height": null,
665
+ "justify_content": null,
666
+ "justify_items": null,
667
+ "left": null,
668
+ "margin": null,
669
+ "max_height": null,
670
+ "max_width": null,
671
+ "min_height": null,
672
+ "min_width": null,
673
+ "object_fit": null,
674
+ "object_position": null,
675
+ "order": null,
676
+ "overflow": null,
677
+ "overflow_x": null,
678
+ "overflow_y": null,
679
+ "padding": null,
680
+ "right": null,
681
+ "top": null,
682
+ "visibility": null,
683
+ "width": null
684
+ }
685
+ },
686
+ "e0db11f2d42744e982e2206a1640344a": {
687
+ "model_module": "@jupyter-widgets/controls",
688
+ "model_name": "DescriptionStyleModel",
689
+ "model_module_version": "1.5.0",
690
+ "state": {
691
+ "_model_module": "@jupyter-widgets/controls",
692
+ "_model_module_version": "1.5.0",
693
+ "_model_name": "DescriptionStyleModel",
694
+ "_view_count": null,
695
+ "_view_module": "@jupyter-widgets/base",
696
+ "_view_module_version": "1.2.0",
697
+ "_view_name": "StyleView",
698
+ "description_width": ""
699
+ }
700
+ }
701
+ }
702
+ }
703
+ },
704
+ "cells": [
705
+ {
706
+ "cell_type": "markdown",
707
+ "source": [
708
+ "## using pre-trained swin transformer to train the model"
709
+ ],
710
+ "metadata": {
711
+ "id": "kuFuGmlMnSSF"
712
+ }
713
+ },
714
+ {
715
+ "cell_type": "code",
716
+ "source": [
717
+ "import torch\n",
718
+ "from torchvision import datasets, transforms\n",
719
+ "from torch.utils.data import DataLoader\n",
720
+ "\n",
721
+ "# Define transformation pipeline\n",
722
+ "transform = transforms.Compose([\n",
723
+ " transforms.Resize((224, 224)), # Resize to 224x224 for Swin Transformer\n",
724
+ " transforms.ToTensor(),\n",
725
+ "])\n",
726
+ "\n",
727
+ "# Load dataset\n",
728
+ "train_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/data/brain_tumor_dataset/train', transform=transform)\n",
729
+ "val_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/data/brain_tumor_dataset/test', transform=transform)\n",
730
+ "\n",
731
+ "# Create dataloaders\n",
732
+ "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
733
+ "val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)\n"
734
+ ],
735
+ "metadata": {
736
+ "id": "dacVmRJWSB-7"
737
+ },
738
+ "execution_count": 1,
739
+ "outputs": []
740
+ },
741
+ {
742
+ "cell_type": "code",
743
+ "source": [
744
+ "from transformers import SwinForImageClassification\n",
745
+ "\n",
746
+ "# Load the pre-trained Swin Transformer model with 4 output classes\n",
747
+ "model = SwinForImageClassification.from_pretrained(\n",
748
+ " 'microsoft/swin-tiny-patch4-window7-224',\n",
749
+ " num_labels=2, # Number of tumor types\n",
750
+ " ignore_mismatched_sizes=True # Ignore size mismatch for the classifier layer\n",
751
+ ")"
752
+ ],
753
+ "metadata": {
754
+ "colab": {
755
+ "base_uri": "https://localhost:8080/",
756
+ "height": 170,
757
+ "referenced_widgets": [
758
+ "dc7ab4859abf4fdaa598729ec421129a",
759
+ "5441781ef8f046dd979a8c389a8e7587",
760
+ "999c37c9cce448a792df58915545135f",
761
+ "db66afed001a4324bdd44f51efeaafa1",
762
+ "11cb47a0455044af83969ae1994af22f",
763
+ "f884e42910e14b5aa9dc9919558cf75f",
764
+ "53cc94f964d54d71923bc295d2882216",
765
+ "b0a3d7e1d2ea4e3bb7e3264c8089b26d",
766
+ "279f9349c7334713bb760bbc4bff3d3c",
767
+ "865cd86f5627433e9d92c1de46917f08",
768
+ "73db46b258794cfe922424d06689e44e",
769
+ "cd77cd0c2d1741c39b9ca6f36ec562c3",
770
+ "4f28a143f3cd438b87a74a62d9c5bea4",
771
+ "256dc2937cb2449f84ec14b9dae2e51f",
772
+ "a60bc29c8c1044f690a6cd4f74ed7c2c",
773
+ "a33d503834a741a9a23a26c03c1c3aaf",
774
+ "a2f0b441f346498faabf97d5d8eabd3a",
775
+ "3243191638d14067a1565ae2f03f48de",
776
+ "f0b74718119e4f3a9bb6295bc7567f6d",
777
+ "fd6d8576cbc444e1b22111ddf4fa1dc2",
778
+ "c469ed61c6b946b78bb0308f36cb915b",
779
+ "e0db11f2d42744e982e2206a1640344a"
780
+ ]
781
+ },
782
+ "id": "WQb3MZ2oSTLr",
783
+ "outputId": "f159d174-0444-44de-dfe0-b77a20694146"
784
+ },
785
+ "execution_count": 2,
786
+ "outputs": [
787
+ {
788
+ "output_type": "display_data",
789
+ "data": {
790
+ "text/plain": [
791
+ "config.json: 0%| | 0.00/71.8k [00:00<?, ?B/s]"
792
+ ],
793
+ "application/vnd.jupyter.widget-view+json": {
794
+ "version_major": 2,
795
+ "version_minor": 0,
796
+ "model_id": "dc7ab4859abf4fdaa598729ec421129a"
797
+ }
798
+ },
799
+ "metadata": {}
800
+ },
801
+ {
802
+ "output_type": "display_data",
803
+ "data": {
804
+ "text/plain": [
805
+ "model.safetensors: 0%| | 0.00/113M [00:00<?, ?B/s]"
806
+ ],
807
+ "application/vnd.jupyter.widget-view+json": {
808
+ "version_major": 2,
809
+ "version_minor": 0,
810
+ "model_id": "cd77cd0c2d1741c39b9ca6f36ec562c3"
811
+ }
812
+ },
813
+ "metadata": {}
814
+ },
815
+ {
816
+ "output_type": "stream",
817
+ "name": "stderr",
818
+ "text": [
819
+ "Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:\n",
820
+ "- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated\n",
821
+ "- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated\n",
822
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
823
+ ]
824
+ }
825
+ ]
826
+ },
827
+ {
828
+ "cell_type": "code",
829
+ "source": [
830
+ "# Freeze all layers except the final classification head\n",
831
+ "for param in model.parameters():\n",
832
+ " param.requires_grad = False\n",
833
+ "\n",
834
+ "# Unfreeze the classification layer\n",
835
+ "for param in model.classifier.parameters():\n",
836
+ " param.requires_grad = True\n"
837
+ ],
838
+ "metadata": {
839
+ "id": "_oCiVhFbSbBQ"
840
+ },
841
+ "execution_count": 3,
842
+ "outputs": []
843
+ },
844
+ {
845
+ "cell_type": "code",
846
+ "source": [
847
+ "import torch.optim as optim\n",
848
+ "from torch.optim.lr_scheduler import StepLR\n",
849
+ "import torch.nn as nn\n",
850
+ "\n",
851
+ "# Set up optimizer and loss function\n",
852
+ "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
853
+ "criterion = nn.CrossEntropyLoss()\n",
854
+ "\n",
855
+ "# Move model to GPU if available\n",
856
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
857
+ "model.to(device)"
858
+ ],
859
+ "metadata": {
860
+ "colab": {
861
+ "base_uri": "https://localhost:8080/"
862
+ },
863
+ "id": "sfu4me_OTSJx",
864
+ "outputId": "55f7ff18-72c3-45b2-a97c-587b81c122b5"
865
+ },
866
+ "execution_count": 4,
867
+ "outputs": [
868
+ {
869
+ "output_type": "execute_result",
870
+ "data": {
871
+ "text/plain": [
872
+ "SwinForImageClassification(\n",
873
+ " (swin): SwinModel(\n",
874
+ " (embeddings): SwinEmbeddings(\n",
875
+ " (patch_embeddings): SwinPatchEmbeddings(\n",
876
+ " (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))\n",
877
+ " )\n",
878
+ " (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)\n",
879
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
880
+ " )\n",
881
+ " (encoder): SwinEncoder(\n",
882
+ " (layers): ModuleList(\n",
883
+ " (0): SwinStage(\n",
884
+ " (blocks): ModuleList(\n",
885
+ " (0-1): 2 x SwinLayer(\n",
886
+ " (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)\n",
887
+ " (attention): SwinAttention(\n",
888
+ " (self): SwinSelfAttention(\n",
889
+ " (query): Linear(in_features=96, out_features=96, bias=True)\n",
890
+ " (key): Linear(in_features=96, out_features=96, bias=True)\n",
891
+ " (value): Linear(in_features=96, out_features=96, bias=True)\n",
892
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
893
+ " )\n",
894
+ " (output): SwinSelfOutput(\n",
895
+ " (dense): Linear(in_features=96, out_features=96, bias=True)\n",
896
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
897
+ " )\n",
898
+ " )\n",
899
+ " (drop_path): SwinDropPath(p=0.1)\n",
900
+ " (layernorm_after): LayerNorm((96,), eps=1e-05, elementwise_affine=True)\n",
901
+ " (intermediate): SwinIntermediate(\n",
902
+ " (dense): Linear(in_features=96, out_features=384, bias=True)\n",
903
+ " (intermediate_act_fn): GELUActivation()\n",
904
+ " )\n",
905
+ " (output): SwinOutput(\n",
906
+ " (dense): Linear(in_features=384, out_features=96, bias=True)\n",
907
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
908
+ " )\n",
909
+ " )\n",
910
+ " )\n",
911
+ " (downsample): SwinPatchMerging(\n",
912
+ " (reduction): Linear(in_features=384, out_features=192, bias=False)\n",
913
+ " (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
914
+ " )\n",
915
+ " )\n",
916
+ " (1): SwinStage(\n",
917
+ " (blocks): ModuleList(\n",
918
+ " (0-1): 2 x SwinLayer(\n",
919
+ " (layernorm_before): LayerNorm((192,), eps=1e-05, elementwise_affine=True)\n",
920
+ " (attention): SwinAttention(\n",
921
+ " (self): SwinSelfAttention(\n",
922
+ " (query): Linear(in_features=192, out_features=192, bias=True)\n",
923
+ " (key): Linear(in_features=192, out_features=192, bias=True)\n",
924
+ " (value): Linear(in_features=192, out_features=192, bias=True)\n",
925
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
926
+ " )\n",
927
+ " (output): SwinSelfOutput(\n",
928
+ " (dense): Linear(in_features=192, out_features=192, bias=True)\n",
929
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
930
+ " )\n",
931
+ " )\n",
932
+ " (drop_path): SwinDropPath(p=0.1)\n",
933
+ " (layernorm_after): LayerNorm((192,), eps=1e-05, elementwise_affine=True)\n",
934
+ " (intermediate): SwinIntermediate(\n",
935
+ " (dense): Linear(in_features=192, out_features=768, bias=True)\n",
936
+ " (intermediate_act_fn): GELUActivation()\n",
937
+ " )\n",
938
+ " (output): SwinOutput(\n",
939
+ " (dense): Linear(in_features=768, out_features=192, bias=True)\n",
940
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
941
+ " )\n",
942
+ " )\n",
943
+ " )\n",
944
+ " (downsample): SwinPatchMerging(\n",
945
+ " (reduction): Linear(in_features=768, out_features=384, bias=False)\n",
946
+ " (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
947
+ " )\n",
948
+ " )\n",
949
+ " (2): SwinStage(\n",
950
+ " (blocks): ModuleList(\n",
951
+ " (0-5): 6 x SwinLayer(\n",
952
+ " (layernorm_before): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
953
+ " (attention): SwinAttention(\n",
954
+ " (self): SwinSelfAttention(\n",
955
+ " (query): Linear(in_features=384, out_features=384, bias=True)\n",
956
+ " (key): Linear(in_features=384, out_features=384, bias=True)\n",
957
+ " (value): Linear(in_features=384, out_features=384, bias=True)\n",
958
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
959
+ " )\n",
960
+ " (output): SwinSelfOutput(\n",
961
+ " (dense): Linear(in_features=384, out_features=384, bias=True)\n",
962
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
963
+ " )\n",
964
+ " )\n",
965
+ " (drop_path): SwinDropPath(p=0.1)\n",
966
+ " (layernorm_after): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
967
+ " (intermediate): SwinIntermediate(\n",
968
+ " (dense): Linear(in_features=384, out_features=1536, bias=True)\n",
969
+ " (intermediate_act_fn): GELUActivation()\n",
970
+ " )\n",
971
+ " (output): SwinOutput(\n",
972
+ " (dense): Linear(in_features=1536, out_features=384, bias=True)\n",
973
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
974
+ " )\n",
975
+ " )\n",
976
+ " )\n",
977
+ " (downsample): SwinPatchMerging(\n",
978
+ " (reduction): Linear(in_features=1536, out_features=768, bias=False)\n",
979
+ " (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)\n",
980
+ " )\n",
981
+ " )\n",
982
+ " (3): SwinStage(\n",
983
+ " (blocks): ModuleList(\n",
984
+ " (0-1): 2 x SwinLayer(\n",
985
+ " (layernorm_before): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
986
+ " (attention): SwinAttention(\n",
987
+ " (self): SwinSelfAttention(\n",
988
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
989
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
990
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
991
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
992
+ " )\n",
993
+ " (output): SwinSelfOutput(\n",
994
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
995
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
996
+ " )\n",
997
+ " )\n",
998
+ " (drop_path): SwinDropPath(p=0.1)\n",
999
+ " (layernorm_after): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
1000
+ " (intermediate): SwinIntermediate(\n",
1001
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
1002
+ " (intermediate_act_fn): GELUActivation()\n",
1003
+ " )\n",
1004
+ " (output): SwinOutput(\n",
1005
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
1006
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
1007
+ " )\n",
1008
+ " )\n",
1009
+ " )\n",
1010
+ " )\n",
1011
+ " )\n",
1012
+ " )\n",
1013
+ " (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
1014
+ " (pooler): AdaptiveAvgPool1d(output_size=1)\n",
1015
+ " )\n",
1016
+ " (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
1017
+ ")"
1018
+ ]
1019
+ },
1020
+ "metadata": {},
1021
+ "execution_count": 4
1022
+ }
1023
+ ]
1024
+ },
1025
+ {
1026
+ "cell_type": "code",
1027
+ "source": [
1028
+ "# Training loop\n",
1029
+ "for epoch in range(10): # Train for 10 epochs (adjust as needed)\n",
1030
+ " model.train()\n",
1031
+ " running_loss = 0.0\n",
1032
+ " correct = 0\n",
1033
+ " total = 0\n",
1034
+ "\n",
1035
+ " for inputs, labels in train_loader:\n",
1036
+ " inputs, labels = inputs.to(device), labels.to(device)\n",
1037
+ "\n",
1038
+ " # Zero the parameter gradients\n",
1039
+ " optimizer.zero_grad()\n",
1040
+ "\n",
1041
+ " # Forward pass\n",
1042
+ " outputs = model(inputs).logits\n",
1043
+ " loss = criterion(outputs, labels)\n",
1044
+ "\n",
1045
+ " # Backward pass and optimization\n",
1046
+ " loss.backward()\n",
1047
+ " optimizer.step()\n",
1048
+ "\n",
1049
+ " # Calculate accuracy\n",
1050
+ " _, predicted = torch.max(outputs, 1)\n",
1051
+ " total += labels.size(0)\n",
1052
+ " correct += (predicted == labels).sum().item()\n",
1053
+ " running_loss += loss.item()\n",
1054
+ "\n",
1055
+ " # Print training stats\n",
1056
+ " print(f'Epoch [{epoch+1}/10], Loss: {running_loss/len(train_loader)}, Accuracy: {100 * correct / total}%')\n",
1057
+ "\n",
1058
+ " # Validation\n",
1059
+ " model.eval()\n",
1060
+ " val_correct = 0\n",
1061
+ " val_total = 0\n",
1062
+ " with torch.no_grad():\n",
1063
+ " for inputs, labels in val_loader:\n",
1064
+ " inputs, labels = inputs.to(device), labels.to(device)\n",
1065
+ " outputs = model(inputs).logits\n",
1066
+ " _, predicted = torch.max(outputs, 1)\n",
1067
+ " val_total += labels.size(0)\n",
1068
+ " val_correct += (predicted == labels).sum().item()\n",
1069
+ "\n",
1070
+ " print(f'Validation Accuracy: {100 * val_correct / val_total}%')"
1071
+ ],
1072
+ "metadata": {
1073
+ "colab": {
1074
+ "base_uri": "https://localhost:8080/"
1075
+ },
1076
+ "id": "2Q3sxrxtTXzX",
1077
+ "outputId": "c5a6b813-c5a4-43b8-cad5-48a146034ece"
1078
+ },
1079
+ "execution_count": 5,
1080
+ "outputs": [
1081
+ {
1082
+ "output_type": "stream",
1083
+ "name": "stdout",
1084
+ "text": [
1085
+ "Epoch [1/10], Loss: 0.6803878396749496, Accuracy: 58.13692480359147%\n",
1086
+ "Validation Accuracy: 59.060402684563755%\n",
1087
+ "Epoch [2/10], Loss: 0.6379124437059674, Accuracy: 72.61503928170595%\n",
1088
+ "Validation Accuracy: 69.12751677852349%\n",
1089
+ "Epoch [3/10], Loss: 0.6045689774411065, Accuracy: 77.32884399551067%\n",
1090
+ "Validation Accuracy: 74.49664429530202%\n",
1091
+ "Epoch [4/10], Loss: 0.5734582436936242, Accuracy: 79.12457912457913%\n",
1092
+ "Validation Accuracy: 78.52348993288591%\n",
1093
+ "Epoch [5/10], Loss: 0.5508207274334771, Accuracy: 80.13468013468014%\n",
1094
+ "Validation Accuracy: 80.53691275167785%\n",
1095
+ "Epoch [6/10], Loss: 0.5296014377049038, Accuracy: 80.69584736251403%\n",
1096
+ "Validation Accuracy: 78.52348993288591%\n",
1097
+ "Epoch [7/10], Loss: 0.5103116855025291, Accuracy: 82.37934904601572%\n",
1098
+ "Validation Accuracy: 79.86577181208054%\n",
1099
+ "Epoch [8/10], Loss: 0.48474655938999994, Accuracy: 83.72615039281706%\n",
1100
+ "Validation Accuracy: 77.85234899328859%\n",
1101
+ "Epoch [9/10], Loss: 0.48020742727177484, Accuracy: 83.16498316498317%\n",
1102
+ "Validation Accuracy: 79.19463087248322%\n",
1103
+ "Epoch [10/10], Loss: 0.458157547882625, Accuracy: 84.51178451178451%\n",
1104
+ "Validation Accuracy: 79.86577181208054%\n"
1105
+ ]
1106
+ }
1107
+ ]
1108
+ },
1109
+ {
1110
+ "cell_type": "code",
1111
+ "source": [
1112
+ "# Test the model\n",
1113
+ "model.eval()\n",
1114
+ "test_correct = 0\n",
1115
+ "test_total = 0\n",
1116
+ "\n",
1117
+ "with torch.no_grad():\n",
1118
+ " for inputs, labels in val_loader:\n",
1119
+ " inputs, labels = inputs.to(device), labels.to(device)\n",
1120
+ " outputs = model(inputs).logits\n",
1121
+ " _, predicted = torch.max(outputs, 1)\n",
1122
+ " test_total += labels.size(0)\n",
1123
+ " test_correct += (predicted == labels).sum().item()\n",
1124
+ "\n",
1125
+ "print(f'Test Accuracy: {100 * test_correct / test_total}%')"
1126
+ ],
1127
+ "metadata": {
1128
+ "colab": {
1129
+ "base_uri": "https://localhost:8080/"
1130
+ },
1131
+ "id": "la_3QpX9TfiP",
1132
+ "outputId": "3a1f21ca-81a9-4cd8-9d66-ec899d4e0dea"
1133
+ },
1134
+ "execution_count": 6,
1135
+ "outputs": [
1136
+ {
1137
+ "output_type": "stream",
1138
+ "name": "stdout",
1139
+ "text": [
1140
+ "Test Accuracy: 79.86577181208054%\n"
1141
+ ]
1142
+ }
1143
+ ]
1144
+ },
1145
+ {
1146
+ "cell_type": "code",
1147
+ "source": [
1148
+ "torch.save(model.state_dict(), 'swin_brain_tumor_classifier.pth')"
1149
+ ],
1150
+ "metadata": {
1151
+ "id": "RYS7IhD2T1-a"
1152
+ },
1153
+ "execution_count": 7,
1154
+ "outputs": []
1155
+ },
1156
+ {
1157
+ "cell_type": "code",
1158
+ "source": [
1159
+ "from PIL import Image"
1160
+ ],
1161
+ "metadata": {
1162
+ "id": "f-mJ9OWQUYdB"
1163
+ },
1164
+ "execution_count": 8,
1165
+ "outputs": []
1166
+ },
1167
+ {
1168
+ "cell_type": "code",
1169
+ "source": [
1170
+ "# Load the saved model\n",
1171
+ "model.load_state_dict(torch.load('swin_brain_tumor_classifier.pth'))\n",
1172
+ "model.eval()\n",
1173
+ "\n",
1174
+ "# Make predictions on new data\n",
1175
+ "img = Image.open('/content/drive/MyDrive/data/brain_tumor_dataset/test/healthy/0566.jpg')\n",
1176
+ "img = transform(img).unsqueeze(0).to(device)\n",
1177
+ "\n",
1178
+ "# Predict\n",
1179
+ "output = model(img).logits\n",
1180
+ "_, predicted = torch.max(output, 1)\n",
1181
+ "print(f'Predicted class: {predicted.item()}')"
1182
+ ],
1183
+ "metadata": {
1184
+ "colab": {
1185
+ "base_uri": "https://localhost:8080/"
1186
+ },
1187
+ "id": "Eq0_X__wT3RV",
1188
+ "outputId": "6bdfa0ef-cba7-448b-8a5f-ef46f93b8677"
1189
+ },
1190
+ "execution_count": 20,
1191
+ "outputs": [
1192
+ {
1193
+ "output_type": "stream",
1194
+ "name": "stderr",
1195
+ "text": [
1196
+ "<ipython-input-20-c22c92cd5d87>:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
1197
+ " model.load_state_dict(torch.load('swin_brain_tumor_classifier.pth'))\n"
1198
+ ]
1199
+ },
1200
+ {
1201
+ "output_type": "stream",
1202
+ "name": "stdout",
1203
+ "text": [
1204
+ "Predicted class: 0\n"
1205
+ ]
1206
+ }
1207
+ ]
1208
+ },
1209
+ {
1210
+ "cell_type": "code",
1211
+ "source": [
1212
+ "#healthy tumor\n",
1213
+ "path = '/content/drive/MyDrive/data/brain_tumor_dataset/test/healthy'\n"
1214
+ ],
1215
+ "metadata": {
1216
+ "id": "gzK2C_4-k8F1"
1217
+ },
1218
+ "execution_count": 21,
1219
+ "outputs": []
1220
+ },
1221
+ {
1222
+ "cell_type": "code",
1223
+ "source": [
1224
+ "import os"
1225
+ ],
1226
+ "metadata": {
1227
+ "id": "6EDJd38AxZFv"
1228
+ },
1229
+ "execution_count": 22,
1230
+ "outputs": []
1231
+ },
1232
+ {
1233
+ "cell_type": "code",
1234
+ "source": [
1235
+ "files = os.listdir(path)\n",
1236
+ "\n",
1237
+ "for f in files:\n",
1238
+ " try:\n",
1239
+ " img = Image.open(os.path.join(path,f))\n",
1240
+ " img = transform(img).unsqueeze(0).to(device)\n",
1241
+ " output = model(img).logits\n",
1242
+ " _, predicted = torch.max(output, 1)\n",
1243
+ " print(f'predicted class: {predicted.item()} filename: {f} actual class: 0')\n",
1244
+ " except Exception as e:\n",
1245
+ " print(e)\n",
1246
+ " continue"
1247
+ ],
1248
+ "metadata": {
1249
+ "colab": {
1250
+ "base_uri": "https://localhost:8080/"
1251
+ },
1252
+ "id": "1jTDGblkxdAn",
1253
+ "outputId": "5610b6fc-ac47-4a6b-9314-f5398ea2787d"
1254
+ },
1255
+ "execution_count": 25,
1256
+ "outputs": [
1257
+ {
1258
+ "output_type": "stream",
1259
+ "name": "stdout",
1260
+ "text": [
1261
+ "predicted class: 0 filename: 0796.jpg actual class: 0\n",
1262
+ "predicted class: 0 filename: 0676.jpg actual class: 0\n",
1263
+ "predicted class: 1 filename: 0698.jpg actual class: 0\n",
1264
+ "predicted class: 1 filename: 0601.jpg actual class: 0\n",
1265
+ "predicted class: 0 filename: 0861.jpg actual class: 0\n",
1266
+ "predicted class: 1 filename: 0615.jpg actual class: 0\n",
1267
+ "predicted class: 0 filename: 0874.jpg actual class: 0\n",
1268
+ "predicted class: 0 filename: 0820.jpg actual class: 0\n",
1269
+ "predicted class: 0 filename: 0785.jpg actual class: 0\n",
1270
+ "predicted class: 0 filename: 0792.jpg actual class: 0\n",
1271
+ "predicted class: 0 filename: 0731.jpg actual class: 0\n",
1272
+ "predicted class: 0 filename: 0762.jpg actual class: 0\n",
1273
+ "predicted class: 1 filename: 0710.jpg actual class: 0\n",
1274
+ "predicted class: 0 filename: 0858.jpg actual class: 0\n",
1275
+ "predicted class: 0 filename: 0691.jpg actual class: 0\n",
1276
+ "predicted class: 0 filename: 0791.jpg actual class: 0\n",
1277
+ "predicted class: 1 filename: 0639.jpg actual class: 0\n",
1278
+ "predicted class: 1 filename: 0596.jpg actual class: 0\n",
1279
+ "predicted class: 1 filename: 0591.jpg actual class: 0\n",
1280
+ "predicted class: 0 filename: 0730.jpg actual class: 0\n",
1281
+ "predicted class: 0 filename: 0638.jpg actual class: 0\n",
1282
+ "predicted class: 0 filename: 0566.jpg actual class: 0\n",
1283
+ "predicted class: 1 filename: 0645.jpg actual class: 0\n",
1284
+ "predicted class: 0 filename: 0565.jpg actual class: 0\n",
1285
+ "predicted class: 1 filename: 0778.jpg actual class: 0\n",
1286
+ "predicted class: 0 filename: 0697.jpg actual class: 0\n",
1287
+ "predicted class: 1 filename: 0800.jpg actual class: 0\n",
1288
+ "predicted class: 0 filename: 0857.jpg actual class: 0\n",
1289
+ "predicted class: 0 filename: 0879.jpg actual class: 0\n",
1290
+ "predicted class: 0 filename: 0765.jpg actual class: 0\n",
1291
+ "predicted class: 0 filename: 0562.jpg actual class: 0\n",
1292
+ "predicted class: 0 filename: 0719.jpg actual class: 0\n",
1293
+ "predicted class: 0 filename: 0740.jpg actual class: 0\n",
1294
+ "predicted class: 0 filename: 0607.jpg actual class: 0\n",
1295
+ "predicted class: 0 filename: 0580.jpg actual class: 0\n",
1296
+ "predicted class: 1 filename: 0839.jpg actual class: 0\n",
1297
+ "predicted class: 0 filename: 0860.jpg actual class: 0\n",
1298
+ "predicted class: 0 filename: 0718.jpg actual class: 0\n",
1299
+ "predicted class: 0 filename: 0793.jpg actual class: 0\n",
1300
+ "predicted class: 0 filename: 0881.jpg actual class: 0\n",
1301
+ "predicted class: 0 filename: 0864.jpg actual class: 0\n",
1302
+ "predicted class: 0 filename: 0696.jpg actual class: 0\n",
1303
+ "predicted class: 0 filename: 0724.jpg actual class: 0\n",
1304
+ "predicted class: 0 filename: 0703.jpg actual class: 0\n",
1305
+ "predicted class: 0 filename: 0721.jpg actual class: 0\n",
1306
+ "predicted class: 1 filename: 0652.jpg actual class: 0\n",
1307
+ "predicted class: 0 filename: 0551.jpg actual class: 0\n",
1308
+ "predicted class: 0 filename: 0720.jpg actual class: 0\n",
1309
+ "predicted class: 0 filename: 0689.jpg actual class: 0\n",
1310
+ "predicted class: 0 filename: 0795.jpg actual class: 0\n",
1311
+ "predicted class: 1 filename: 0571.jpg actual class: 0\n",
1312
+ "predicted class: 0 filename: 0640.jpg actual class: 0\n",
1313
+ "predicted class: 0 filename: 0806.jpg actual class: 0\n",
1314
+ "predicted class: 0 filename: 0761.jpg actual class: 0\n",
1315
+ "predicted class: 1 filename: 0715.jpg actual class: 0\n",
1316
+ "predicted class: 0 filename: 0884.jpg actual class: 0\n",
1317
+ "predicted class: 1 filename: 0684.jpg actual class: 0\n",
1318
+ "predicted class: 0 filename: 0846.jpg actual class: 0\n",
1319
+ "predicted class: 0 filename: 0805.jpg actual class: 0\n",
1320
+ "predicted class: 0 filename: 0872.jpg actual class: 0\n",
1321
+ "predicted class: 1 filename: 0707.jpg actual class: 0\n",
1322
+ "predicted class: 0 filename: 0868.jpg actual class: 0\n",
1323
+ "predicted class: 1 filename: 0863.jpg actual class: 0\n",
1324
+ "predicted class: 1 filename: 0871.jpg actual class: 0\n",
1325
+ "predicted class: 0 filename: 0859.jpg actual class: 0\n",
1326
+ "predicted class: 1 filename: 0769.jpg actual class: 0\n",
1327
+ "predicted class: 0 filename: 0888.jpg actual class: 0\n",
1328
+ "predicted class: 0 filename: 0733.jpg actual class: 0\n",
1329
+ "predicted class: 0 filename: 0835.jpg actual class: 0\n",
1330
+ "predicted class: 1 filename: 0841.jpg actual class: 0\n",
1331
+ "predicted class: 0 filename: 0768.jpg actual class: 0\n",
1332
+ "predicted class: 1 filename: 0878.jpg actual class: 0\n",
1333
+ "[Errno 21] Is a directory: '/content/drive/MyDrive/data/brain_tumor_dataset/test/healthy/.ipynb_checkpoints'\n"
1334
+ ]
1335
+ }
1336
+ ]
1337
+ },
1338
+ {
1339
+ "cell_type": "code",
1340
+ "source": [
1341
+ "#calculating the accuracy\n",
1342
+ "def calculate_accuracy(model, img_path, img_files, actual_class):\n",
1343
+ " total_images = len(img_files)\n",
1344
+ " predicted_ones = 0\n",
1345
+ " for i in img_files:\n",
1346
+ " try:\n",
1347
+ " img = Image.open(os.path.join(img_path,i))\n",
1348
+ " img = transform(img).unsqueeze(0).to(device)\n",
1349
+ " output = model(img).logits\n",
1350
+ " _, predicted = torch.max(output, 1)\n",
1351
+ " if int(predicted.item()) == int(actual_class):\n",
1352
+ " predicted_ones += 1\n",
1353
+ " except Exception as e:\n",
1354
+ " continue\n",
1355
+ " accuracy_score = (predicted_ones/total_images)*100\n",
1356
+ " return accuracy_score"
1357
+ ],
1358
+ "metadata": {
1359
+ "id": "j92hyOmB15e0"
1360
+ },
1361
+ "execution_count": 39,
1362
+ "outputs": []
1363
+ },
1364
+ {
1365
+ "cell_type": "code",
1366
+ "source": [
1367
+ "img_path = '/content/drive/MyDrive/data/brain_tumor_dataset/train/healthy'\n",
1368
+ "img_files = os.listdir(img_path)\n",
1369
+ "print(\"Accuracy score:\",calculate_accuracy(model, img_path, img_files, 0))"
1370
+ ],
1371
+ "metadata": {
1372
+ "colab": {
1373
+ "base_uri": "https://localhost:8080/"
1374
+ },
1375
+ "id": "mAMy1XYG5DCw",
1376
+ "outputId": "5ff76b58-7972-4b05-ea88-5758e0c17483"
1377
+ },
1378
+ "execution_count": 40,
1379
+ "outputs": [
1380
+ {
1381
+ "output_type": "stream",
1382
+ "name": "stdout",
1383
+ "text": [
1384
+ "Accuracy score: 62.02830188679245\n"
1385
+ ]
1386
+ }
1387
+ ]
1388
+ },
1389
+ {
1390
+ "cell_type": "code",
1391
+ "source": [
1392
+ "img_path = '/content/drive/MyDrive/data/brain_tumor_dataset/train/tumor'\n",
1393
+ "img_files = os.listdir(img_path)\n",
1394
+ "print(\"Accuracy score:\",calculate_accuracy(model, img_path, img_files, 1))"
1395
+ ],
1396
+ "metadata": {
1397
+ "colab": {
1398
+ "base_uri": "https://localhost:8080/"
1399
+ },
1400
+ "id": "I4y62G1K5lXF",
1401
+ "outputId": "fb1f6ebe-0862-4f62-e268-8bc42fd803f7"
1402
+ },
1403
+ "execution_count": 41,
1404
+ "outputs": [
1405
+ {
1406
+ "output_type": "stream",
1407
+ "name": "stdout",
1408
+ "text": [
1409
+ "Accuracy score: 85.65310492505354\n"
1410
+ ]
1411
+ }
1412
+ ]
1413
+ },
1414
+ {
1415
+ "cell_type": "markdown",
1416
+ "source": [
1417
+ "### For healthy class model accuracy score is 62%\n",
1418
+ "### For tumor images model accuracy score is 85%"
1419
+ ],
1420
+ "metadata": {
1421
+ "id": "oLFI0ASV7YKJ"
1422
+ }
1423
+ }
1424
+ ]
1425
+ }