Commit 785ef2b: Initial commit (no parents)
Files changed:
- dl_supervised_pipeline.py +158 -0
- run_gradio.py +55 -0
- svm_pipeline.py +100 -0
- utils/MAE.py +253 -0
- utils/__init__.py +1 -0
- utils/__pycache__/MAE.cpython-311.pyc +0 -0
- utils/__pycache__/MAE.cpython-38.pyc +0 -0
- utils/__pycache__/MAE.cpython-38.pyc:Zone.Identifier +3 -0
- utils/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/__pycache__/__init__.cpython-310.pyc:Zone.Identifier +3 -0
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/__init__.cpython-38.pyc +0 -0
- utils/__pycache__/__init__.cpython-38.pyc:Zone.Identifier +3 -0
- utils/__pycache__/__init__.cpython-39.pyc +0 -0
- utils/__pycache__/__init__.cpython-39.pyc:Zone.Identifier +3 -0
- utils/__pycache__/arg_utils.cpython-38.pyc +0 -0
- utils/__pycache__/arg_utils.cpython-38.pyc:Zone.Identifier +3 -0
- utils/__pycache__/arg_utils.cpython-39.pyc +0 -0
- utils/__pycache__/arg_utils.cpython-39.pyc:Zone.Identifier +3 -0
- utils/__pycache__/experiment_utils.cpython-311.pyc +0 -0
- utils/__pycache__/experiment_utils.cpython-38.pyc +0 -0
- utils/__pycache__/experiment_utils.cpython-38.pyc:Zone.Identifier +3 -0
- utils/__pycache__/experiment_utils.cpython-39.pyc +0 -0
- utils/__pycache__/experiment_utils.cpython-39.pyc:Zone.Identifier +3 -0
- utils/__pycache__/model_utils.cpython-311.pyc +0 -0
- utils/__pycache__/model_utils.cpython-38.pyc +0 -0
- utils/__pycache__/model_utils.cpython-38.pyc:Zone.Identifier +3 -0
- utils/__pycache__/util_function.cpython-310.pyc +0 -0
- utils/__pycache__/util_function.cpython-310.pyc:Zone.Identifier +3 -0
- utils/__pycache__/util_function.cpython-311.pyc +0 -0
- utils/__pycache__/util_function.cpython-38.pyc +0 -0
- utils/__pycache__/util_function.cpython-38.pyc:Zone.Identifier +3 -0
- utils/__pycache__/util_function.cpython-39.pyc +0 -0
- utils/__pycache__/util_function.cpython-39.pyc:Zone.Identifier +3 -0
- utils/arg_utils.py +18 -0
- utils/experiment_utils.py +298 -0
- utils/model_utils.py +96 -0
- utils/util_function.py +238 -0
- vis_confusion_mtx.py +54 -0
- vote_analysis.py +107 -0
dl_supervised_pipeline.py
ADDED
@@ -0,0 +1,158 @@
# Code modified from pytorch-image-classification
# obtained from https://colab.research.google.com/github/bentrevett/pytorch-image-classification/blob/master/5_resnet.ipynb#scrollTo=4QmwmcXuPuLo

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

import torch.utils.data as data

import numpy as np
import random
import tqdm
import os
from pathlib import Path

from data_utils.data_tribology import TribologyDataset
from utils.experiment_utils import get_model, get_name, get_logger, train, evaluate, evaluate_vote
from utils.arg_utils import get_args

def main(args):
    '''Reproducibility'''
    SEED = args.seed
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    '''Folder Creation'''
    basepath = os.getcwd()
    experiment_dir = Path(os.path.join(basepath, 'experiments', args.model, args.resolution, args.magnification, args.modality, args.pretrained, args.frozen, args.vote))
    experiment_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_dir = Path(os.path.join(experiment_dir, 'checkpoints'))
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    '''Logging'''
    model_name = get_name(args)
    print(model_name, 'STARTED')
    if os.path.exists(checkpoint_dir / 'epoch10.pth'):
        print('CHECKPOINT FOUND')
        print('TERMINATING TRAINING')
        return 0  # terminate training if checkpoint exists

    logger = get_logger(experiment_dir, model_name)

    '''Data Loading'''
    train_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_train.csv"
    test_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_test.csv"
    img_path = f"./LUA_Dataset/{args.resolution}/{args.magnification}/{args.modality}"

    # results_acc_1 = {}
    # results_acc_3 = {}
    # classes_num = 6
    BATCHSIZE = args.batch_size
    train_dataset = TribologyDataset(csv_path=train_csv_path, img_path=img_path)
    test_dataset = TribologyDataset(csv_path=test_csv_path, img_path=img_path)

    # prepare the data augmentation
    means, stds = train_dataset.get_statistics()
    train_dataset.prepare_transform(means, stds, mode='train')
    test_dataset.prepare_transform(means, stds, mode='test')

    VALID_RATIO = 0.1

    num_train = len(train_dataset)
    num_valid = int(VALID_RATIO * num_train)
    train_dataset, valid_dataset = data.random_split(train_dataset, [num_train - num_valid, num_valid])
    logger.info(f'Number of training samples: {len(train_dataset)}')
    logger.info(f'Number of validation samples: {len(valid_dataset)}')
    train_iterator = torch.utils.data.DataLoader(train_dataset,
                                                 batch_size=BATCHSIZE,
                                                 num_workers=4,
                                                 shuffle=True,
                                                 pin_memory=True,
                                                 drop_last=False)

    valid_iterator = torch.utils.data.DataLoader(valid_dataset,
                                                 batch_size=BATCHSIZE,
                                                 num_workers=4,
                                                 shuffle=True,
                                                 pin_memory=True,
                                                 drop_last=False)
    test_iterator = torch.utils.data.DataLoader(test_dataset,
                                                batch_size=BATCHSIZE,
                                                num_workers=4,
                                                shuffle=False,
                                                pin_memory=True,
                                                drop_last=False)
    print('DATA LOADED')

    # Define model
    model = get_model(args)
    print('MODEL LOADED')

    # Define optimizer and scheduler
    START_LR = args.start_lr
    optimizer = optim.Adam(model.parameters(), lr=START_LR)
    STEPS_PER_EPOCH = len(train_iterator)
    print('STEPS_PER_EPOCH:', STEPS_PER_EPOCH)
    print('VALIDATION STEPS:', len(valid_iterator))
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=max(STEPS_PER_EPOCH, STEPS_PER_EPOCH // 10))

    # Define loss function
    criterion = nn.CrossEntropyLoss()

    # Define device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = criterion.to(device)

    EPOCHS = args.epochs

    print('SETUP DONE')
    # train our model

    print('TRAINING STARTED')
    for epoch in tqdm.tqdm(range(EPOCHS)):

        train_loss, train_acc_1, train_acc_3 = train(model, train_iterator, optimizer, criterion, scheduler, device)

        torch.cuda.empty_cache()  # clear cache between train and val

        valid_loss, valid_acc_1, valid_acc_3 = evaluate(model, valid_iterator, criterion, device)

        torch.save(model.state_dict(), checkpoint_dir / f'epoch{epoch+1}.pth')

        logger.info(f'Epoch: {epoch + 1:02}')
        logger.info(f'\tTrain Loss: {train_loss:.3f} | Train Acc @1: {train_acc_1 * 100:6.2f}% | ' \
                    f'Train Acc @3: {train_acc_3 * 100:6.2f}%')
        logger.info(f'\tValid Loss: {valid_loss:.3f} | Valid Acc @1: {valid_acc_1 * 100:6.2f}% | ' \
                    f'Valid Acc @3: {valid_acc_3 * 100:6.2f}%')

    logger.info('-------------------End of Training-------------------')
    print('TRAINING DONE')
    logger.info('-------------------Beginning of Testing-------------------')
    print('TESTING STARTED')
    for epoch in tqdm.tqdm(range(EPOCHS)):
        model.load_state_dict(torch.load(checkpoint_dir / f'epoch{epoch+1}.pth'))

        if args.vote == 'vote':
            test_acc = evaluate_vote(model, test_iterator, device)
            logger.info(f'Test Acc @1: {test_acc * 100:6.2f}%')
        else:
            test_loss, test_acc_1, test_acc_3 = evaluate(model, test_iterator, criterion, device)

            logger.info(f'Test Acc @1: {test_acc_1 * 100:6.2f}% | ' \
                        f'Test Acc @3: {test_acc_3 * 100:6.2f}%')
    logger.info('-------------------End of Testing-------------------')
    print('TESTING DONE')


if __name__ == '__main__':
    args = get_args()
    main(args)
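Since train() (defined in utils/experiment_utils.py below) calls scheduler.step() once per batch, T_0 = STEPS_PER_EPOCH restarts the cosine schedule once per epoch; note that max(STEPS_PER_EPOCH, STEPS_PER_EPOCH // 10) always evaluates to STEPS_PER_EPOCH. A minimal standalone sketch of the resulting schedule, with a hypothetical steps_per_epoch of 10:

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=0.01)
steps_per_epoch = 10  # hypothetical; the pipeline uses len(train_iterator)
scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=steps_per_epoch)

for step in range(2 * steps_per_epoch):  # two epochs' worth of batches
    optimizer.step()   # gradients omitted in this sketch
    scheduler.step()   # per-batch step, as in train()
    print(step, round(optimizer.param_groups[0]['lr'], 5))  # decays toward 0, restarts at 0.01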
run_gradio.py
ADDED
@@ -0,0 +1,55 @@
import gradio as gr
import torch
import torchvision
from utils.experiment_utils import get_model


# Load the DINOv2 model
def load_model():
    class Args:
        model = 'DINOv2'
        pretrained = 'pretrained'
        frozen = 'unfrozen'

    args = Args()
    model = get_model(args)
    model.eval()
    return model


model = load_model()


# Prediction function: returns the probability of each class
def predict(image):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    image = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(image)
        probabilities = torch.nn.functional.softmax(output, dim=1).squeeze().tolist()

    # List of class names
    class_names = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]

    # Pair each class with its corresponding probability
    results = {class_names[i]: probabilities[i] for i in range(len(class_names))}

    return results


# Build the Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=len(["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"])),
    title="LUWA DINOv2 Prediction",
    description="Upload an image to get the probabilities for each class using the DINOv2 model."
)

if __name__ == "__main__":
    interface.launch(share=True)
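A minimal usage sketch of predict() outside the Gradio UI, assuming a hypothetical local test image sample.png:

from PIL import Image

image = Image.open('sample.png').convert('RGB')  # hypothetical test image
probs = predict(image)                           # dict of class name -> probability
print(max(probs, key=probs.get))                 # most likely class name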
svm_pipeline.py
ADDED
@@ -0,0 +1,100 @@
import numpy as np
from sklearn.svm import LinearSVC

from skimage.feature import fisher_vector, learn_gmm

import random
import os
from pathlib import Path

from data_utils.data_tribology import TribologyDataset
from utils.arg_utils import get_args
from utils.experiment_utils import get_name, get_logger, SIFT_extraction, conduct_voting
from utils.visualization_utils import plot_confusion_matrix
from vis_confusion_mtx import generate_confusion_matrix

def main(args):
    '''Reproducibility'''
    SEED = args.seed
    random.seed(SEED)
    np.random.seed(SEED)

    '''Folder Creation'''
    basepath = os.getcwd()
    experiment_dir = Path(os.path.join(basepath, 'experiments', args.model, args.resolution, args.magnification, args.modality, args.vote))
    experiment_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_dir = Path(os.path.join(experiment_dir, 'checkpoints'))
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    '''Logging'''
    model_name = get_name(args)
    print(model_name, 'STARTED', flush=True)
    logger = get_logger(experiment_dir, model_name)

    '''Data Loading'''
    train_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_train.csv"
    test_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_test.csv"
    img_path = f"./LUA_Dataset/{args.resolution}/{args.magnification}/{args.modality}"

    BATCHSIZE = args.batch_size
    train_dataset = TribologyDataset(csv_path=train_csv_path, img_path=img_path)
    test_dataset = TribologyDataset(csv_path=test_csv_path, img_path=img_path)

    # prepare the data augmentation
    means, stds = train_dataset.get_statistics()
    train_dataset.prepare_transform(means, stds, mode='train')
    test_dataset.prepare_transform(means, stds, mode='test')

    VALID_RATIO = 0.1

    num_train = len(train_dataset)
    num_valid = int(VALID_RATIO * num_train)
    # train_dataset, valid_dataset = data.random_split(train_dataset, [num_train - num_valid, num_valid])
    # logger.info(f'Number of training samples: {len(train_dataset)}')
    # logger.info(f'Number of validation samples: {len(valid_dataset)}')

    train_names, train_descriptor, train_labels = SIFT_extraction(train_dataset)
    test_names, test_descriptor, test_labels = SIFT_extraction(test_dataset)
    # val_descriptor, val_labels = SIFT_extraction(valid_dataset)
    print('DATA LOADED', flush=True)

    print('TRAINING STARTED', flush=True)

    # Train a K-mode GMM
    k = 16
    gmm = learn_gmm(train_descriptor, n_modes=k)

    # Compute the Fisher vectors
    training_fvs = np.array([
        fisher_vector(descriptor_mat, gmm)
        for descriptor_mat in train_descriptor
    ])

    testing_fvs = np.array([
        fisher_vector(descriptor_mat, gmm)
        for descriptor_mat in test_descriptor
    ])

    svm = LinearSVC().fit(training_fvs, train_labels)

    logger.info('-------------------End of Training-------------------')
    print('TRAINING DONE')
    logger.info('-------------------Beginning of Testing-------------------')
    print('TESTING STARTED')
    predictions = svm.predict(testing_fvs)
    conduct_voting(test_names, predictions)
    plot_confusion_matrix('visualization_results/SIFT+FVs_confusion_mtx.png', predictions, test_labels, classes=["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"])
    correct = 0
    for i in range(len(predictions)):
        if predictions[i] == test_labels[i]:
            correct += 1
    test_acc = float(correct) / len(predictions)
    logger.info(f'Test Acc @1: {test_acc * 100:6.2f}%')

    logger.info('-------------------End of Testing-------------------')
    print('TESTING DONE')

if __name__ == '__main__':
    args = get_args()
    main(args)
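A minimal standalone sketch of the SIFT-descriptor to GMM to Fisher-vector to LinearSVC chain used above, with synthetic descriptors standing in for SIFT_extraction's output (which is not shown in this commit):

import numpy as np
from sklearn.svm import LinearSVC
from skimage.feature import fisher_vector, learn_gmm

rng = np.random.default_rng(0)
# ten "images", each with a variable number of 128-D SIFT-like descriptors
descriptors = [rng.normal(size=(int(rng.integers(5, 20)), 128)) for _ in range(10)]
labels = np.array([0, 1] * 5)  # synthetic binary labels

gmm = learn_gmm(descriptors, n_modes=4)                       # fit a 4-mode GMM to all descriptors
fvs = np.array([fisher_vector(d, gmm) for d in descriptors])  # one fixed-length vector per image
clf = LinearSVC().fit(fvs, labels)
print(clf.predict(fvs[:3]))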
utils/MAE.py
ADDED
@@ -0,0 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import PatchEmbed, Block

from utils.model_utils import get_2d_sincos_pos_embed


class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        # pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        # loss = self.forward_loss(imgs, pred, mask)
        # return loss, pred, mask
        print(latent.shape)  # debug output
        return latent



def mae_vit_base_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_large_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


# set recommended archs
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
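Note that forward() here is modified from upstream MAE: the decoder and reconstruction loss are commented out, so the model returns the encoder latent and acts as a feature extractor. A minimal round-trip sketch of patchify/unpatchify (assuming timm is installed, since PatchEmbed and Block are imported at module load):

import torch
from utils.MAE import mae_vit_base_patch16

model = mae_vit_base_patch16()
imgs = torch.randn(2, 3, 224, 224)
patches = model.patchify(imgs)                          # [2, 196, 16*16*3]
assert torch.allclose(model.unpatchify(patches), imgs)  # exact inverse for square images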
utils/__init__.py
ADDED
@@ -0,0 +1 @@
from .util_function import epoch_time, plot_lr_finder, plot_confusion_matrix, plot_most_incorrect, get_pca, plot_representations, plot_filtered_images, plot_filters
utils/__pycache__/MAE.cpython-311.pyc
ADDED
Binary file (14 kB).

utils/__pycache__/MAE.cpython-38.pyc
ADDED
Binary file (7.16 kB).

utils/__pycache__/MAE.cpython-38.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (422 Bytes).

utils/__pycache__/__init__.cpython-310.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (516 Bytes).

utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (376 Bytes).

utils/__pycache__/__init__.cpython-38.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (392 Bytes).

utils/__pycache__/__init__.cpython-39.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/arg_utils.cpython-38.pyc
ADDED
Binary file (1.03 kB).

utils/__pycache__/arg_utils.cpython-38.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/arg_utils.cpython-39.pyc
ADDED
Binary file (1.05 kB).

utils/__pycache__/arg_utils.cpython-39.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/experiment_utils.cpython-311.pyc
ADDED
Binary file (12.9 kB).

utils/__pycache__/experiment_utils.cpython-38.pyc
ADDED
Binary file (5.71 kB).

utils/__pycache__/experiment_utils.cpython-38.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/experiment_utils.cpython-39.pyc
ADDED
Binary file (4.93 kB).

utils/__pycache__/experiment_utils.cpython-39.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/model_utils.cpython-311.pyc
ADDED
Binary file (4.24 kB).

utils/__pycache__/model_utils.cpython-38.pyc
ADDED
Binary file (2.4 kB).

utils/__pycache__/model_utils.cpython-38.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/util_function.cpython-310.pyc
ADDED
Binary file (5.35 kB).

utils/__pycache__/util_function.cpython-310.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/util_function.cpython-311.pyc
ADDED
Binary file (14.6 kB).

utils/__pycache__/util_function.cpython-38.pyc
ADDED
Binary file (6.8 kB).

utils/__pycache__/util_function.cpython-38.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/util_function.cpython-39.pyc
ADDED
Binary file (6.82 kB).

utils/__pycache__/util_function.cpython-39.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
[ZoneTransfer]
ZoneId=3
ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/arg_utils.py
ADDED
@@ -0,0 +1,18 @@
import argparse

def get_args():
    # Training settings
    parser = argparse.ArgumentParser('train')

    parser.add_argument('--resolution', type=str, default='256', help='Resolution of input image')
    parser.add_argument('--magnification', type=str, default='20x', help='Magnification of input image')
    parser.add_argument('--modality', type=str, default='texture', help='Modality of input image')
    parser.add_argument('--model', type=str, default='ResNet50', help='Model to use')
    parser.add_argument('--pretrained', type=str, default='pretrained', help='Use pretrained model')
    parser.add_argument('--frozen', type=str, default='unfrozen', help='Freeze pretrained model')
    parser.add_argument('--vote', type=str, default='vote', help='Conduct voting')
    parser.add_argument('--epochs', type=int, default=2, help='Number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=100, help='Batch size')
    parser.add_argument('--start_lr', type=float, default=0.01, help='Learning rate')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed')
    return parser.parse_args()
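The flags are plain strings compared literally elsewhere (e.g. args.frozen == 'frozen'), so an invocation spells each option out. An illustrative command line (values shown are simply the defaults above):

python dl_supervised_pipeline.py --model ResNet50 --resolution 256 --magnification 20x \
    --modality texture --pretrained pretrained --frozen unfrozen --vote vote \
    --epochs 2 --batch_size 100 --start_lr 0.01 --seed 1234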
utils/experiment_utils.py
ADDED
@@ -0,0 +1,298 @@
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import logging
from collections import Counter
from utils.MAE import mae_vit_large_patch16_dec512d8b as MAE_large

def get_model(args) -> nn.Module:
    if 'ResNet' in args.model:
        # resnet family
        if args.model == 'ResNet50':
            if args.pretrained == 'pretrained':
                model = torchvision.models.resnet50(weights='IMAGENET1K_V2')
            else:
                model = torchvision.models.resnet50()
        elif args.model == 'ResNet152':
            if args.pretrained == 'pretrained':
                model = torchvision.models.resnet152(weights='IMAGENET1K_V2')
            else:
                model = torchvision.models.resnet152()
        else:
            raise NotImplementedError
        if args.frozen == 'frozen':
            model = freeze_backbone(model)
        model.fc = nn.Linear(model.fc.in_features, 6)

    elif 'ConvNext' in args.model:
        if args.model == 'ConvNext_Tiny':
            if args.pretrained == 'pretrained':
                model = torchvision.models.convnext_tiny(weights='IMAGENET1K_V1')
            else:
                model = torchvision.models.convnext_tiny()
        elif args.model == 'ConvNext_Large':
            if args.pretrained == 'pretrained':
                model = torchvision.models.convnext_large(weights='IMAGENET1K_V1')
            else:
                model = torchvision.models.convnext_large()
        else:
            raise NotImplementedError
        if args.frozen == 'frozen':
            model = freeze_backbone(model)
        num_ftrs = model.classifier[2].in_features
        model.classifier[2] = nn.Linear(int(num_ftrs), 6)

    elif 'ViT' in args.model:
        if args.pretrained == 'pretrained':
            model = torchvision.models.vit_h_14(weights='IMAGENET1K_SWAG_LINEAR_V1')
        else:
            raise NotImplementedError('ViT does not support training from scratch')
        if args.frozen == 'frozen':
            model = freeze_backbone(model)
        model.heads[0] = torch.nn.Linear(model.heads[0].in_features, 6)

    elif 'DINOv2' in args.model:
        if args.pretrained == 'pretrained':
            model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg_lc')
        else:
            raise NotImplementedError('DINOv2 does not support training from scratch')
        if args.frozen == 'frozen':
            model = freeze_backbone(model)
        model.linear_head = torch.nn.Linear(model.linear_head.in_features, 6)

    elif 'MAE' in args.model:
        if args.pretrained == 'pretrained':
            model = MAE_large()
            model.load_state_dict(torch.load('/scratch/zf540/LUWA/workspace/utils/pretrained_weights/mae_visualize_vit_large.pth')['model'])
        else:
            raise NotImplementedError('MAE does not support training from scratch')
        if args.frozen == 'frozen':
            model = freeze_backbone(model)
        model = nn.Sequential(model, nn.Linear(1024, 6))
        print(model)
    else:
        raise NotImplementedError
    return model


def freeze_backbone(model):
    # freeze backbone
    # we will replace the classifier at the end with a trainable one anyway, so we freeze the default here as well
    for param in model.parameters():
        param.requires_grad = False
    return model

def get_name(args):
    name = args.model
    name += '_' + str(args.resolution)
    name += '_' + args.magnification
    name += '_' + args.modality
    if args.pretrained == 'pretrained':
        name += '_pretrained'
    else:
        name += '_scratch'
    if args.frozen == 'frozen':
        name += '_frozen'
    else:
        name += '_unfrozen'
    if args.vote == 'vote':
        name += '_vote'
    else:
        name += '_novote'
    return name

def get_logger(path, name):
    # set up logger

    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(path.joinpath(f'{name}_log.txt'))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info('---------------------------------------------------TRAINING---------------------------------------------------')

    return logger

def calculate_topk_accuracy(y_pred, y, k=3):
    with torch.no_grad():
        batch_size = y.shape[0]
        _, top_pred = y_pred.topk(k, 1)
        top_pred = top_pred.t()
        correct = top_pred.eq(y.view(1, -1).expand_as(top_pred))
        correct_1 = correct[:1].reshape(-1).float().sum(0, keepdim=True)
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        acc_1 = correct_1 / batch_size
        acc_k = correct_k / batch_size
    return acc_1, acc_k

def train(model, iterator, optimizer, criterion, scheduler, device):
    epoch_loss = 0
    epoch_acc_1 = 0
    epoch_acc_3 = 0

    model.train()

    for image, label, image_name in iterator:
        x = image.to(device)
        y = label.to(device)

        optimizer.zero_grad()

        y_pred = model(x)
        print(y_pred.shape)  # debug: shape of the model output
        print(y.shape)       # debug: shape of the targets
        loss = criterion(y_pred, y)

        acc_1, acc_3 = calculate_topk_accuracy(y_pred, y)

        loss.backward()

        optimizer.step()

        scheduler.step()

        epoch_loss += loss.item()
        epoch_acc_1 += acc_1.item()
        epoch_acc_3 += acc_3.item()

    epoch_loss /= len(iterator)
    epoch_acc_1 /= len(iterator)
    epoch_acc_3 /= len(iterator)

    return epoch_loss, epoch_acc_1, epoch_acc_3


def evaluate(model, iterator, criterion, device):
    epoch_loss = 0
    epoch_acc_1 = 0
    epoch_acc_3 = 0

    model.eval()

    with torch.no_grad():
        for image, label, image_name in iterator:
            x = image.to(device)
            y = label.to(device)

            y_pred = model(x)
            loss = criterion(y_pred, y)

            acc_1, acc_3 = calculate_topk_accuracy(y_pred, y)

            epoch_loss += loss.item()
            epoch_acc_1 += acc_1.item()
            epoch_acc_3 += acc_3.item()

    epoch_loss /= len(iterator)
    epoch_acc_1 /= len(iterator)
    epoch_acc_3 /= len(iterator)

    return epoch_loss, epoch_acc_1, epoch_acc_3

def evaluate_vote(model, iterator, device):

    model.eval()

    image_names = []
    labels = []
    predictions = []

    with torch.no_grad():

        for image, label, image_name in iterator:

            x = image.to(device)

            y_pred = model(x)
            y_prob = F.softmax(y_pred, dim=-1)
            top_pred = y_prob.argmax(1, keepdim=True)

            image_names.extend(image_name)
            labels.extend(label.numpy())
            predictions.extend(top_pred.cpu().squeeze().numpy())

    conduct_voting(image_names, predictions)

    correct_count = 0
    for i in range(len(labels)):
        if labels[i] == predictions[i]:
            correct_count += 1
    accuracy = correct_count / len(labels)
    return accuracy

def conduct_voting(image_names, predictions):
    # we need to do this because not all stones have the same number of partitions
    last_stone = image_names[0][:-8]  # the name of the stone of the last image
    voting_list = []
    for i in range(len(image_names)):
        image_area_name = image_names[i][:-8]
        if image_area_name != last_stone:
            # we have run through all the images of the last stone, so we start voting
            vote(voting_list, predictions, i)
            voting_list = []  # reset the voting list
        voting_list.append(predictions[i])
        last_stone = image_area_name  # update the last stone name

    # vote for the last stone
    vote(voting_list, predictions, len(image_names))

def vote(voting_list, predictions, i):
    vote_result = Counter(voting_list).most_common(1)[0][0]  # the most common prediction in the list
    predictions[i-len(voting_list):i] = [vote_result] * len(voting_list)  # replace the predictions of the last stone with the vote result




# def get_predictions(model, iterator):
#     model.eval()
#     images = []
#     labels = []
#     probs = []
#     with torch.no_grad():
#         for (x, y) in iterator:
#             x = x.to(device)
#             y_pred = model(x)
#             y_prob = F.softmax(y_pred, dim = -1)
#             top_pred = y_prob.argmax(1, keepdim = True)
#             images.append(x.cpu())
#             labels.append(y.cpu())
#             probs.append(y_prob.cpu())
#     images = torch.cat(images, dim = 0)
#     labels = torch.cat(labels, dim = 0)
#     probs = torch.cat(probs, dim = 0)
#     return images, labels, probs


# def get_representations(model, iterator):
#     model.eval()
#     outputs = []
#     intermediates = []
#     labels = []
#     with torch.no_grad():
#         for (x, y) in iterator:
#             x = x.to(device)
#             y_pred = model(x)
#             outputs.append(y_pred.cpu())
#             labels.append(y)
#     outputs = torch.cat(outputs, dim=0)
#     labels = torch.cat(labels, dim=0)
#     return outputs, labels
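A minimal worked example of the voting scheme in conduct_voting/vote, assuming the repository's dependencies are installed: image names that share a stem (everything except the last 8 characters, i.e. an '_001.png'-style suffix) are treated as crops of one stone, and each group's predictions are overwritten with the group's majority label. The names below are hypothetical:

from utils.experiment_utils import conduct_voting

names = ['stoneA_001.png', 'stoneA_002.png', 'stoneA_003.png',
         'stoneB_001.png', 'stoneB_002.png']
preds = [0, 0, 1, 2, 3]
conduct_voting(names, preds)
print(preds)  # [0, 0, 0, 2, 2]; ties fall to the first-seen label via Counter.most_common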
utils/model_utils.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np

import torch

# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
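A minimal shape check for the sin-cos embedding that MAE.py freezes into pos_embed (a 224-px image with 16-px patches gives a 14x14 grid; 768 is the base model's encoder width):

import numpy as np
from utils.model_utils import get_2d_sincos_pos_embed

pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pe.shape)  # (197, 768): 14*14 patch positions plus an all-zero cls row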
utils/util_function.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import cv2
from sklearn.manifold import TSNE
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn import decomposition
import itertools

def normalize_image(image):
    # min-max scale a tensor image to [0, 1] in place
    image_min = image.min()
    image_max = image.max()
    image.clamp_(min=image_min, max=image_max)
    image.add_(-image_min).div_(image_max - image_min + 1e-5)
    return image


def plot_lr_finder(fig_name, lrs, losses, skip_start=5, skip_end=5):
    if skip_end == 0:
        lrs = lrs[skip_start:]
        losses = losses[skip_start:]
    else:
        lrs = lrs[skip_start:-skip_end]
        losses = losses[skip_start:-skip_end]

    fig = plt.figure(figsize=(16, 8))
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(lrs, losses)
    ax.set_xscale('log')
    ax.set_xlabel('Learning rate')
    ax.set_ylabel('Loss')
    ax.grid(True, 'both', 'x')
    plt.savefig(fig_name)  # save before plt.show() so the written figure is not blank
    plt.show()

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs


def plot_confusion_matrix(fig_name, labels, pred_labels, classes):
    fig = plt.figure(figsize=(50, 50))
    ax = fig.add_subplot(1, 1, 1)
    cm = confusion_matrix(labels, pred_labels)
    cm = ConfusionMatrixDisplay(cm, display_labels=classes)
    cm.plot(values_format='d', cmap='Blues', ax=ax)
    fig.delaxes(fig.axes[1])  # delete colorbar
    plt.xticks(rotation=90, fontsize=50)
    plt.yticks(fontsize=50)
    plt.rcParams.update({'font.size': 50})
    plt.xlabel('Predicted Label', fontsize=50)
    plt.ylabel('True Label', fontsize=50)
    plt.savefig(fig_name)

def plot_confusion_matrix_SVM(fig_name, true_labels, predicted_labels, classes):
    fig = plt.figure(figsize=(100, 100))
    ax = fig.add_subplot(1, 1, 1)

    cm = confusion_matrix(true_labels, predicted_labels)
    cm_display = ConfusionMatrixDisplay(cm, display_labels=classes)

    cm_display.plot(values_format='d', cmap='Blues', ax=ax)

    fig.delaxes(fig.axes[1])  # delete colorbar
    plt.xticks(rotation=90, fontsize=50)
    plt.yticks(fontsize=50)
    plt.rcParams.update({'font.size': 50})
    plt.xlabel('Predicted Label', fontsize=50)
    plt.ylabel('True Label', fontsize=50)
    plt.savefig(fig_name)


def plot_most_incorrect(fig_name, incorrect, classes, n_images, normalize=True):
    rows = int(np.sqrt(n_images))
    cols = int(np.sqrt(n_images))

    fig = plt.figure(figsize=(25, 20))

    for i in range(rows * cols):

        ax = fig.add_subplot(rows, cols, i + 1)

        image, true_label, probs = incorrect[i]
        image = image.permute(1, 2, 0)
        true_prob = probs[true_label]
        incorrect_prob, incorrect_label = torch.max(probs, dim=0)
        true_class = classes[true_label]
        incorrect_class = classes[incorrect_label]

        if normalize:
            image = normalize_image(image)

        ax.imshow(image.cpu().numpy())
        ax.set_title(f'true label: {true_class} ({true_prob:.3f})\n'
                     f'pred label: {incorrect_class} ({incorrect_prob:.3f})')
        ax.axis('off')

    fig.subplots_adjust(hspace=0.4)
    plt.savefig(fig_name)

def get_pca(data, n_components=2):
    pca = decomposition.PCA()
    pca.n_components = n_components
    pca_data = pca.fit_transform(data)
    return pca_data


def plot_representations(fig_name, data, labels, classes, n_images=None):
    if n_images is not None:
        data = data[:n_images]
        labels = labels[:n_images]

    fig = plt.figure(figsize=(15, 15))
    ax = fig.add_subplot(111)
    scatter = ax.scatter(data[:, 0], data[:, 1], c=labels, cmap='hsv')
    # handles, _ = scatter.legend_elements(num = None)
    # legend = plt.legend(handles = handles, labels = classes)
    plt.savefig(fig_name)


def plot_filtered_images(fig_name, images, filters, n_filters=None, normalize=True):

    images = torch.cat([i.unsqueeze(0) for i in images], dim=0).cpu()
    filters = filters.cpu()

    if n_filters is not None:
        filters = filters[:n_filters]

    n_images = images.shape[0]
    n_filters = filters.shape[0]

    filtered_images = F.conv2d(images, filters)

    fig = plt.figure(figsize=(30, 30))

    for i in range(n_images):

        image = images[i]

        if normalize:
            image = normalize_image(image)

        ax = fig.add_subplot(n_images, n_filters + 1, i + 1 + (i * n_filters))
        ax.imshow(image.permute(1, 2, 0).numpy())
        ax.set_title('Original')
        ax.axis('off')

        for j in range(n_filters):
            image = filtered_images[i][j]

            if normalize:
                image = normalize_image(image)

            ax = fig.add_subplot(n_images, n_filters + 1, i + 1 + (i * n_filters) + j + 1)
            ax.imshow(image.numpy(), cmap='bone')
            ax.set_title(f'Filter {j+1}')
            ax.axis('off')

    fig.subplots_adjust(hspace=-0.7)
    plt.savefig(fig_name)


def plot_filters(fig_name, filters, normalize=True):
    filters = filters.cpu()

    n_filters = filters.shape[0]

    rows = int(np.sqrt(n_filters))
    cols = int(np.sqrt(n_filters))

    fig = plt.figure(figsize=(30, 15))

    for i in range(rows * cols):

        image = filters[i]

        if normalize:
            image = normalize_image(image)

        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(image.permute(1, 2, 0))
        ax.axis('off')

    fig.subplots_adjust(wspace=-0.9)
    plt.savefig(fig_name)

def plot_tsne(fig_name, all_features, all_labels):
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(all_features)
    plt.figure(figsize=(10, 7))
    scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=all_labels, cmap='viridis', s=5)
    plt.colorbar(scatter)
    plt.title('t-SNE Visualization')
    plt.savefig(fig_name)  # save before plt.show() so the written figure is not blank
    plt.show()


def plot_grad_cam(images, cams, predicted_labels, true_labels, classes, path):
    fig, axs = plt.subplots(nrows=2, ncols=len(images), figsize=(20, 10))

    for i, (img, cam, pred_label, true_label) in enumerate(zip(images, cams, predicted_labels, true_labels)):
        # Display the original image on the top row
        axs[0, i].imshow(img.permute(1, 2, 0).cpu().numpy())
        pred_class_name = classes[pred_label]
        true_class_name = classes[true_label]
        axs[0, i].set_title(f"Predicted: {pred_class_name}\nTrue: {true_class_name}", fontsize=12)
        axs[0, i].axis('off')

        # Add label to the leftmost plot
        if i == 0:
            axs[0, i].set_ylabel("Original Image", fontsize=14, rotation=90, labelpad=10)

        # Convert the original image to grayscale
        grayscale_img = cv2.cvtColor(img.permute(1, 2, 0).cpu().numpy(), cv2.COLOR_RGB2GRAY)
        grayscale_img = cv2.cvtColor(grayscale_img, cv2.COLOR_GRAY2RGB)

        # Overlay the Grad-CAM heatmap on the grayscale image
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        cam_img = heatmap + np.float32(grayscale_img)
        cam_img = cam_img / np.max(cam_img)

        # Display the Grad-CAM image on the bottom row
        axs[1, i].imshow(cam_img)
        axs[1, i].axis('off')

        # Add label to the leftmost plot
        if i == 0:
            axs[1, i].set_ylabel("Grad-CAM", fontsize=14, rotation=90, labelpad=10)

    plt.tight_layout()
    plt.savefig(path)
    plt.close()
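These helpers are imported by the training and analysis scripts in this commit. A minimal smoke test with random stand-in data (not results from this project) shows the calling conventions of the three most commonly used ones:

# Hypothetical smoke test for the helpers above (random placeholder data).
import numpy as np
from utils.util_function import plot_confusion_matrix, get_pca, plot_representations

classes = ['ANTLER', 'BONE', 'IVORY']        # any subset of label names works
labels = np.random.randint(0, 3, size=200)   # fake ground truth
preds = np.random.randint(0, 3, size=200)    # fake predictions
plot_confusion_matrix('demo_confusion_mtx.png', labels, preds, classes)

features = np.random.randn(200, 64)          # e.g. penultimate-layer features
pca_2d = get_pca(features, n_components=2)   # (200, 64) -> (200, 2)
plot_representations('demo_pca.png', pca_2d, labels, classes)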
vis_confusion_mtx.py
ADDED
@@ -0,0 +1,54 @@
import torch

import os
from pathlib import Path

from data_utils.data_tribology import TribologyDataset
from utils.experiment_utils import get_model, get_prediction
from utils.arg_utils import get_args
# plot_confusion_matrix is defined in utils/util_function.py in this commit
# (the original import referenced a utils.visualization_utils module that does
# not appear in the commit's file list)
from utils.util_function import plot_confusion_matrix

def generate_confusion_matrix(image_name, model, iterator, device):
    labels, predictions = get_prediction(model, iterator, device)
    plot_confusion_matrix('visualization_results/' + image_name + '_confusion_mtx.png',
                          labels, predictions,
                          classes=["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"])


def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = get_model(args)

    basepath = os.getcwd()
    experiment_dir = Path(os.path.join(basepath, 'experiments', args.model, args.resolution,
                                       args.magnification, args.modality, args.pretrained,
                                       args.frozen, args.vote))
    if args.model == 'ViT':
        experiment_dir = Path(os.path.join(basepath, 'experiments', 'ViT_H', args.resolution,
                                           args.magnification, args.modality, args.pretrained,
                                           args.frozen, args.vote))
    checkpoint_dir = Path(os.path.join(experiment_dir, 'checkpoints'))
    checkpoint_path = checkpoint_dir / f'epoch{str(args.epochs)}.pth'
    model.load_state_dict(torch.load(checkpoint_path))
    model = model.to(device)

    train_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_train.csv"
    test_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_test.csv"
    img_path = f"./LUA_Dataset/{args.resolution}/{args.magnification}/{args.modality}"
    BATCHSIZE = args.batch_size
    train_dataset = TribologyDataset(csv_path=train_csv_path, img_path=img_path)
    test_dataset = TribologyDataset(csv_path=test_csv_path, img_path=img_path)

    # normalization statistics come from the training split
    means, stds = train_dataset.get_statistics()
    train_dataset.prepare_transform(means, stds, mode='train')
    test_dataset.prepare_transform(means, stds, mode='test')

    test_iterator = torch.utils.data.DataLoader(test_dataset,
                                                batch_size=BATCHSIZE,
                                                num_workers=4,
                                                shuffle=False,
                                                pin_memory=True,
                                                drop_last=False)

    generate_confusion_matrix(args.model, model, test_iterator, device)

if __name__ == "__main__":
    args = get_args()
    main(args)
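The command-line interface comes from get_args in utils/arg_utils.py, which is not shown in this hunk. Assuming it exposes the attributes that main reads, the script can also be driven programmatically; every value below is a placeholder, not a setting from this commit:

# Hypothetical programmatic invocation (attribute names taken from main() above).
from argparse import Namespace

args = Namespace(
    model='ResNet',          # architecture key consumed by get_model
    resolution='1280x960',   # placeholder dataset resolution folder
    magnification='20x',     # placeholder magnification folder
    modality='Texture',      # placeholder modality folder
    pretrained='pretrained',
    frozen='unfrozen',
    vote='vote',
    epochs=100,              # selects checkpoints/epoch100.pth
    batch_size=32,
)
main(args)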
vote_analysis.py
ADDED
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

import torch.utils.data as data

import numpy as np
import random
import tqdm
import os
from pathlib import Path

from data_utils.data_tribology import TribologyDataset
from utils.experiment_utils import get_model, get_name, get_logger, train, evaluate, evaluate_vote, evaluate_vote_analysis
from utils.arg_utils import get_args

def main(args):
    '''Reproducibility'''
    SEED = args.seed
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    '''Folder Creation'''
    basepath = os.getcwd()
    experiment_dir = Path(os.path.join(basepath, 'experiments', args.model, args.resolution,
                                       args.magnification, args.modality, args.pretrained,
                                       args.frozen, args.vote))
    experiment_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_dir = Path(os.path.join(experiment_dir, 'checkpoints'))
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    '''Logging'''
    model_name = get_name(args)
    print(model_name, 'STARTED')

    logger = get_logger(experiment_dir, 'vote_analysis')

    '''Data Loading'''
    train_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_train.csv"
    test_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_test.csv"
    img_path = f"./LUA_Dataset/{args.resolution}/{args.magnification}/{args.modality}"

    # results_acc_1 = {}
    # results_acc_3 = {}
    # classes_num = 6
    BATCHSIZE = args.batch_size
    train_dataset = TribologyDataset(csv_path=train_csv_path, img_path=img_path)
    test_dataset = TribologyDataset(csv_path=test_csv_path, img_path=img_path)

    # prepare the data augmentation
    means, stds = train_dataset.get_statistics()
    train_dataset.prepare_transform(means, stds, mode='train')
    test_dataset.prepare_transform(means, stds, mode='test')

    VALID_RATIO = 0.1

    num_train = len(train_dataset)
    num_valid = int(VALID_RATIO * num_train)
    train_dataset, valid_dataset = data.random_split(train_dataset, [num_train - num_valid, num_valid])
    logger.info(f'Number of training samples: {len(train_dataset)}')
    logger.info(f'Number of validation samples: {len(valid_dataset)}')

    test_iterator = torch.utils.data.DataLoader(test_dataset,
                                                batch_size=BATCHSIZE,
                                                num_workers=4,
                                                shuffle=False,
                                                pin_memory=True,
                                                drop_last=False)
    print('DATA LOADED')

    # Define model
    model = get_model(args)
    print('MODEL LOADED')

    # Define device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    print('SETUP DONE')

    # evaluation only: load the checkpoint saved by the training pipeline
    print('LOADING CHECKPOINT')

    model.load_state_dict(torch.load(checkpoint_dir / f'epoch{args.epochs}.pth'))
    logger.info('-------------------Beginning of Testing-------------------')
    print('TESTING STARTED')

    vote_accuracy, correct_case_accuracy, incorrect_case_accuracy, incorrect_most_common, novote_accuracy = evaluate_vote_analysis(model, test_iterator, device)
    logger.info(f'Test Acc @1: {vote_accuracy * 100:6.2f}%')
    logger.info(f'No Vote Accuracy @1: {novote_accuracy * 100:6.2f}%')
    logger.info(f'Correct Case Consistency @1: {correct_case_accuracy * 100:6.2f}%')
    logger.info(f'Incorrect Case Consistency @1: {incorrect_case_accuracy * 100:6.2f}%')
    logger.info(f'Incorrect Most Common: {incorrect_most_common * 100:6.2f}%')

    logger.info('-------------------End of Testing-------------------')
    print('TESTING DONE')


if __name__ == '__main__':
    args = get_args()
    main(args)
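evaluate_vote_analysis itself is implemented in utils/experiment_utils.py and is not shown in this commit view. Judging by the metric names logged above, it compares per-patch ("no vote") accuracy against a per-image majority vote over patch predictions. A toy, purely illustrative sketch of that aggregation idea (not the actual implementation):

# Illustrative majority-vote aggregation -- a sketch of the idea behind the
# metrics above, not the evaluate_vote_analysis implementation.
from collections import Counter

def majority_vote(patch_predictions):
    # most common predicted class among the patches of one source image
    return Counter(patch_predictions).most_common(1)[0][0]

patch_preds = {'img_0': [2, 2, 5, 2], 'img_1': [1, 3, 3, 3]}   # toy predictions
true_label = {'img_0': 2, 'img_1': 1}                          # toy ground truth
votes = {name: majority_vote(p) for name, p in patch_preds.items()}
vote_accuracy = sum(votes[n] == true_label[n] for n in votes) / len(votes)
print(f'{vote_accuracy * 100:6.2f}%')  # 50.00% for this toy example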