Commit
·
2c8e31c
0
Parent(s):
Clean start: add all files with LFS tracking
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +41 -0
- README.md +16 -0
- app.py +16 -0
- labels.json +42 -0
- model.py +45 -0
- predict.py +91 -0
- requirements.txt +6 -0
- sample_images/airport_terminal/airport-check-in.jpg +3 -0
- sample_images/airport_terminal/checkin.jpg +3 -0
- sample_images/amphitheatre/amphitheatre.png +3 -0
- sample_images/amusement_park/A Swinger Ride.jpg +3 -0
- sample_images/amusement_park/airport-check-in.jpg +3 -0
- sample_images/art_gallery/art_gallery.jpg +3 -0
- sample_images/bakery_shop/ShopInterior.jpg +3 -0
- sample_images/bar/Ram02.jpg +3 -0
- sample_images/bookstore/book_store.jpg +3 -0
- sample_images/botanical_garden/botanical_garden.jpg +3 -0
- sample_images/bridge/Medieval_Exe_Bridge_Exeter.jpg +3 -0
- sample_images/bridge/ironbridge3.jpg +3 -0
- sample_images/bridge/millers.jpg +3 -0
- sample_images/bus_interior/29477487945_81aabab695_b.jpg +3 -0
- sample_images/bus_interior/37139142640_3807d91aea_b.jpg +3 -0
- sample_images/butchers_shop/butcher_shop.jpg +3 -0
- sample_images/campsite/camp_site.png +3 -0
- sample_images/classroom/classroom.png +3 -0
- sample_images/coffee_shop/the-terrace4.jpg +3 -0
- sample_images/construction_site/construction_site.jpeg +3 -0
- sample_images/courtyard/courtyard.jpg +3 -0
- sample_images/driveway/driveway.jpeg +3 -0
- sample_images/fire_station/firestation.jpeg +3 -0
- sample_images/fountain/fountain.jpg +3 -0
- sample_images/gas_station/gas_station.png +3 -0
- sample_images/harbour/Harbour.jpg +3 -0
- sample_images/highway/highway.png +3 -0
- sample_images/kindergarten_classroom/kindergarden_classroon.jpg +3 -0
- sample_images/lobby/lobby.jpg +3 -0
- sample_images/market_outdoor/img_7421.jpg +3 -0
- sample_images/market_outdoor/www.visitexeter.com.jpeg +3 -0
- sample_images/museum/7.jpg +3 -0
- sample_images/museum/albert-queen-1-5.jpg +3 -0
- sample_images/museum/img2295_1.jpg +3 -0
- sample_images/office/images.jpeg +3 -0
- sample_images/office/images2.jpeg +3 -0
- sample_images/parking_lot/parking_lot.png +3 -0
- sample_images/phone_booth/phone_booth.jpg +3 -0
- sample_images/playground/3146371_077d0213.jpg +3 -0
- sample_images/playground/exeter_hall2.jpg +3 -0
- sample_images/playground/planet2.jpg +3 -0
- sample_images/railroad_track/rail_road_track.jpg +3 -0
- sample_images/restaurant/Boston-Tea-Party.jpg +3 -0
.gitattributes
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
40 |
+
*.bmp filter=lfs diff=lfs merge=lfs -text
|
41 |
+
*.webp filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: "AML 16"
|
3 |
+
version: "1.0.0"
|
4 |
+
emoji: "🤗"
|
5 |
+
colorFrom: indigo
|
6 |
+
colorTo: pink
|
7 |
+
sdk: gradio
|
8 |
+
sdk_version: "5.29.0"
|
9 |
+
app_file: app.py
|
10 |
+
pinned: false
|
11 |
+
---
|
12 |
+
|
13 |
+
# AML 16
|
14 |
+
|
15 |
+
This is a Demo using Gradio app for AML 16.
|
16 |
+
|
app.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from predict import predict
|
3 |
+
|
4 |
+
demo = gr.Interface(
|
5 |
+
fn=predict,
|
6 |
+
inputs=gr.Image(type="filepath", label="Upload Image"),
|
7 |
+
outputs=[
|
8 |
+
gr.Image(label="Uploaded Image"),
|
9 |
+
gr.Image(label="Top-1 Class Example"),
|
10 |
+
gr.Label(label="Top-5 Probabilities")
|
11 |
+
],
|
12 |
+
title="Scene Classification with Reference Image",
|
13 |
+
description="Upload an image to get the predicted class with a sample image and top-5 prediction chart."
|
14 |
+
)
|
15 |
+
|
16 |
+
demo.launch()
|
labels.json
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
"airport_terminal",
|
3 |
+
"amphitheatre",
|
4 |
+
"amusement_park",
|
5 |
+
"art_gallery",
|
6 |
+
"bakery_shop",
|
7 |
+
"bar",
|
8 |
+
"bookstore",
|
9 |
+
"botanical_garden",
|
10 |
+
"bridge",
|
11 |
+
"bus_interior",
|
12 |
+
"butchers_shop",
|
13 |
+
"campsite",
|
14 |
+
"classroom",
|
15 |
+
"coffee_shop",
|
16 |
+
"construction_site",
|
17 |
+
"courtyard",
|
18 |
+
"driveway",
|
19 |
+
"fire_station",
|
20 |
+
"fountain",
|
21 |
+
"gas_station",
|
22 |
+
"harbour",
|
23 |
+
"highway",
|
24 |
+
"kindergarten_classroom",
|
25 |
+
"lobby",
|
26 |
+
"market_outdoor",
|
27 |
+
"museum",
|
28 |
+
"office",
|
29 |
+
"parking_lot",
|
30 |
+
"phone_booth",
|
31 |
+
"playground",
|
32 |
+
"railroad_track",
|
33 |
+
"restaurant",
|
34 |
+
"river",
|
35 |
+
"shed",
|
36 |
+
"staircase",
|
37 |
+
"supermarket",
|
38 |
+
"swimming_pool_outdoor",
|
39 |
+
"track",
|
40 |
+
"valley",
|
41 |
+
"yard"
|
42 |
+
]
|
model.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import SwinForImageClassification
|
4 |
+
|
5 |
+
def quantize_model(model, mode="linear"):
|
6 |
+
model.eval().cpu()
|
7 |
+
if mode == "linear":
|
8 |
+
return torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
|
9 |
+
return model
|
10 |
+
|
11 |
+
class SwinModel(nn.Module):
|
12 |
+
def __init__(self, model_name="microsoft/swin-base-patch4-window7-224", num_classes=40, from_pretrained=False):
|
13 |
+
super(SwinModel, self).__init__()
|
14 |
+
|
15 |
+
if from_pretrained:
|
16 |
+
self.model = SwinForImageClassification.from_pretrained(model_name)
|
17 |
+
else:
|
18 |
+
config = SwinForImageClassification.from_pretrained(model_name).config
|
19 |
+
config.num_labels = num_classes
|
20 |
+
self.model = SwinForImageClassification(config)
|
21 |
+
|
22 |
+
in_features = self.model.classifier.in_features
|
23 |
+
self.model.classifier = nn.Linear(in_features, num_classes)
|
24 |
+
|
25 |
+
def forward(self, images):
|
26 |
+
outputs = self.model(images)
|
27 |
+
return outputs.logits
|
28 |
+
|
29 |
+
def load_model(weights_path="best_model.pth", num_classes=40):
|
30 |
+
model = SwinModel(num_classes=num_classes, from_pretrained=False)
|
31 |
+
|
32 |
+
checkpoint = torch.load(weights_path, map_location="cpu")
|
33 |
+
|
34 |
+
if "model_state_dict" in checkpoint:
|
35 |
+
state_dict = checkpoint["model_state_dict"]
|
36 |
+
else:
|
37 |
+
state_dict = checkpoint
|
38 |
+
|
39 |
+
filtered_state_dict = {k: v for k, v in state_dict.items() if "classifier" not in k}
|
40 |
+
|
41 |
+
model.load_state_dict(filtered_state_dict, strict=False)
|
42 |
+
|
43 |
+
model = quantize_model(model, mode="linear")
|
44 |
+
model.eval()
|
45 |
+
return model
|
predict.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision import transforms
|
3 |
+
from PIL import Image
|
4 |
+
import json
|
5 |
+
import numpy as np
|
6 |
+
# from model import load_model
|
7 |
+
from transformers import AutoImageProcessor, SwinForImageClassification
|
8 |
+
import torch.nn as nn
|
9 |
+
import os
|
10 |
+
import pandas as pd
|
11 |
+
import random
|
12 |
+
|
13 |
+
# Load labels
|
14 |
+
with open("labels.json", "r") as f:
|
15 |
+
class_names = json.load(f)
|
16 |
+
print("class_names:", class_names)
|
17 |
+
|
18 |
+
# Load model
|
19 |
+
|
20 |
+
model = SwinForImageClassification.from_pretrained("microsoft/swin-base-patch4-window7-224")
|
21 |
+
|
22 |
+
model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names))
|
23 |
+
|
24 |
+
state_dict = torch.load("best_model.pth", map_location="cpu")
|
25 |
+
|
26 |
+
# Remove incompatible keys (classifier weights)
|
27 |
+
filtered_state_dict = {k: v for k, v in state_dict.items() if "classifier" not in k}
|
28 |
+
model.load_state_dict(filtered_state_dict, strict=False)
|
29 |
+
|
30 |
+
model.eval()
|
31 |
+
|
32 |
+
|
33 |
+
# Image transform
|
34 |
+
# transform = transforms.Compose([
|
35 |
+
# transforms.Resize((224, 224)),
|
36 |
+
# transforms.ToTensor(),
|
37 |
+
# transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
38 |
+
# ])
|
39 |
+
#Swin
|
40 |
+
transform = transforms.Compose([
|
41 |
+
transforms.Resize((224, 224)),
|
42 |
+
transforms.ToTensor(),
|
43 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
44 |
+
std=[0.229, 0.224, 0.225])
|
45 |
+
])
|
46 |
+
|
47 |
+
|
48 |
+
def predict(image_path):
|
49 |
+
# Load and prepare image
|
50 |
+
image = Image.open(image_path).convert("RGB")
|
51 |
+
x = transform(image).unsqueeze(0)
|
52 |
+
|
53 |
+
with torch.no_grad():
|
54 |
+
outputs = model(x)
|
55 |
+
print("Logits:", outputs.logits)
|
56 |
+
probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
|
57 |
+
print("Probs:", probs)
|
58 |
+
print("Sum of probs:", probs.sum())
|
59 |
+
top5 = torch.topk(probs, k=5)
|
60 |
+
|
61 |
+
top1_idx = int(top5.indices[0])
|
62 |
+
top1_label = class_names[top1_idx]
|
63 |
+
|
64 |
+
# Select a random image from the class subfolder
|
65 |
+
class_folder = f"sample_images/{str(top1_label).replace(' ', '_')}"
|
66 |
+
reference_image = None
|
67 |
+
if os.path.isdir(class_folder):
|
68 |
+
# List all image files in the folder
|
69 |
+
image_files = [f for f in os.listdir(class_folder) if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"))]
|
70 |
+
if image_files:
|
71 |
+
chosen_file = random.choice(image_files)
|
72 |
+
ref_path = os.path.join(class_folder, chosen_file)
|
73 |
+
print(f"[DEBUG] Randomly selected reference image: {ref_path}")
|
74 |
+
reference_image = Image.open(ref_path).convert("RGB")
|
75 |
+
else:
|
76 |
+
print(f"[DEBUG] No images found in {class_folder}")
|
77 |
+
else:
|
78 |
+
print(f"[DEBUG] Class folder does not exist: {class_folder}")
|
79 |
+
|
80 |
+
# Format Top-5 for gr.Label with class names
|
81 |
+
top5_probs = {class_names[int(idx)]: float(score) for idx, score in zip(top5.indices, top5.values)}
|
82 |
+
print(f"image path: {image_path}")
|
83 |
+
print(f"top1_label: {top1_label}")
|
84 |
+
print(f"[DEBUG] Top-5 indices: {top5.indices}")
|
85 |
+
print(f"[DEBUG] Top-5 labels: {[class_names[int(idx)] for idx in top5.indices]}")
|
86 |
+
print(f"[DEBUG] Top-5 probs: {top5_probs}")
|
87 |
+
|
88 |
+
return image, reference_image, top5_probs
|
89 |
+
|
90 |
+
|
91 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
transformers
|
4 |
+
Pillow
|
5 |
+
gradio
|
6 |
+
numpy
|
sample_images/airport_terminal/airport-check-in.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/airport_terminal/checkin.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/amphitheatre/amphitheatre.png
ADDED
![]() |
Git LFS Details
|
sample_images/amusement_park/A Swinger Ride.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/amusement_park/airport-check-in.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/art_gallery/art_gallery.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/bakery_shop/ShopInterior.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/bar/Ram02.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/bookstore/book_store.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/botanical_garden/botanical_garden.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/bridge/Medieval_Exe_Bridge_Exeter.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/bridge/ironbridge3.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/bridge/millers.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/bus_interior/29477487945_81aabab695_b.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/bus_interior/37139142640_3807d91aea_b.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/butchers_shop/butcher_shop.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/campsite/camp_site.png
ADDED
![]() |
Git LFS Details
|
sample_images/classroom/classroom.png
ADDED
![]() |
Git LFS Details
|
sample_images/coffee_shop/the-terrace4.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/construction_site/construction_site.jpeg
ADDED
![]() |
Git LFS Details
|
sample_images/courtyard/courtyard.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/driveway/driveway.jpeg
ADDED
![]() |
Git LFS Details
|
sample_images/fire_station/firestation.jpeg
ADDED
![]() |
Git LFS Details
|
sample_images/fountain/fountain.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/gas_station/gas_station.png
ADDED
![]() |
Git LFS Details
|
sample_images/harbour/Harbour.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/highway/highway.png
ADDED
![]() |
Git LFS Details
|
sample_images/kindergarten_classroom/kindergarden_classroon.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/lobby/lobby.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/market_outdoor/img_7421.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/market_outdoor/www.visitexeter.com.jpeg
ADDED
![]() |
Git LFS Details
|
sample_images/museum/7.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/museum/albert-queen-1-5.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/museum/img2295_1.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/office/images.jpeg
ADDED
![]() |
Git LFS Details
|
sample_images/office/images2.jpeg
ADDED
![]() |
Git LFS Details
|
sample_images/parking_lot/parking_lot.png
ADDED
![]() |
Git LFS Details
|
sample_images/phone_booth/phone_booth.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/playground/3146371_077d0213.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/playground/exeter_hall2.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/playground/planet2.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/railroad_track/rail_road_track.jpg
ADDED
![]() |
Git LFS Details
|
sample_images/restaurant/Boston-Tea-Party.jpg
ADDED
![]() |
Git LFS Details
|