"""imagenet-swav-resnet50w2 / featureExtractor.py

Author: lixc
Extract features with the SwAV-ResNet50w2 model.
Commit: d9e5c87
"""
import torch
import torchvision.transforms as tvt
import pandas as pd
import os
from tqdm import tqdm
from PIL import Image
# Where per-patient feature files are written, and the CSV listing every
# image path together with its crop box.
outdir = 'pt_files/train'
yolo_crop_file = 'image_yolo.txt'

# Limit torch's CPU intra-op thread pool to two threads.
torch.set_num_threads(2)
def crop(img, x, y, w, h):
    """Crop *img* to a box given in normalised centre/size coordinates.

    ``x``/``y`` are the box centre and ``w``/``h`` its width/height, expressed
    as fractions of the image width (``x``, ``w``) and height (``y``, ``h``).
    This is presumably the YOLO label format stored in the detection CSV.
    The fractions are scaled to pixels and passed to ``img.crop`` as
    ``(left, upper, right, lower)``. No clamping to the image bounds is done
    here.
    """
    img_w, img_h = img.size
    centre_x, half_w = x * img_w, w * img_w / 2.0
    centre_y, half_h = y * img_h, h * img_h / 2.0
    box = (centre_x - half_w, centre_y - half_h, centre_x + half_w, centre_y + half_h)
    return img.crop(box)
def is_report_file(s):
    """Return True if the path *s* contains the substring ``'RPT'``.

    Such files are treated as report files and excluded from feature
    extraction.
    """
    return 'RPT' in s


def get_barcode(s):
    """Return the patient barcode encoded in the image path *s*.

    The barcode is the third-from-last ``'/'``-separated component, i.e.
    ``.../<barcode>/<subdir>/<file>``. Only ``'/'`` is treated as a separator.

    Raises:
        IndexError: if *s* has fewer than three ``'/'``-separated components.
    """
    return s.split('/')[-3]
# Every crop becomes a CHANNEL x IMAGE_SIZE x IMAGE_SIZE tensor.
CHANNEL, IMAGE_SIZE = 3, 448

# Per-channel (R, G, B) statistics used for input normalisation.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

normalize = tvt.Normalize(
    mean=torch.tensor(IMAGENET_DEFAULT_MEAN),
    std=torch.tensor(IMAGENET_DEFAULT_STD),
)

# The pipeline runs these steps in order:
# 1. Resize(int) scales the shorter edge to IMAGE_SIZE, keeping aspect ratio.
# 2. CenterCrop takes a square IMAGE_SIZE x IMAGE_SIZE window.
# 3. ToTensor converts the PIL image to a float tensor.
# 4. normalize is applied last.
transform_ops = tvt.Compose([
    tvt.Resize(IMAGE_SIZE),
    tvt.CenterCrop(IMAGE_SIZE),
    tvt.ToTensor(),
    normalize,
])

# TorchScript archive of the feature extractor (loaded via torch.jit.load).
model_path = './traced_swav_imagenet_layer2.pt'
# Image metadata: one row per image. `orig` holds the image path, and
# `x`, `y`, `w`, `h` hold the crop box used by the extraction loop.
# Each row is tagged with its patient barcode, and report files are dropped.
df = pd.read_csv(yolo_crop_file)
# astype(bool) keeps the column boolean even for an empty CSV, so `~` below
# is always a logical NOT.
df.insert(0, 'is_report_file', df.orig.map(is_report_file).astype(bool))
df.insert(0, 'patient_barcode', df.orig.map(get_barcode))
df = df[~df['is_report_file']]
# Load the TorchScript feature extractor, move it to the GPU, and switch it
# to inference mode. eval() changes the behaviour of layers such as
# BatchNorm/Dropout, if the model contains any.
# .cuda() and .eval() both return the module itself, so they can be chained.
net = torch.jit.load(model_path).cuda().eval()
# Number of images sent to the GPU per forward pass. Chunking bounds peak GPU
# memory for patients with many images. Outputs are concatenated in order,
# so each saved tensor still has one entry per image along dim 0.
INFERENCE_BATCH_SIZE = 32


def _load_patient_images(rows):
    """Load, crop and preprocess the images listed in *rows* (one patient).

    Returns a CPU float tensor of shape (len(rows), CHANNEL, IMAGE_SIZE,
    IMAGE_SIZE). Entry i corresponds to the i-th row.
    """
    images = torch.zeros(len(rows), CHANNEL, IMAGE_SIZE, IMAGE_SIZE)
    boxes = zip(rows.orig, rows.x, rows.y, rows.w, rows.h)
    for i, (image_file, x, y, w, h) in enumerate(boxes):
        with open(image_file, 'rb') as f:
            # Image.open is lazy. convert() forces the decode, so it must run
            # while the file handle is still open.
            img = Image.open(f).convert('RGB')
        images[i] = transform_ops(crop(img, x, y, w, h))
    return images


def _extract_features(images, batch_size=INFERENCE_BATCH_SIZE):
    """Run `net` on *images* in chunks of *batch_size*; return CPU features.

    Chunk outputs are concatenated along dim 0. This equals a single
    full-batch call as long as the network treats samples independently
    (e.g. a ResNet in eval mode).
    NOTE(review): confirm this holds for the traced model.
    """
    outputs = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size].cuda()
            outputs.append(net(batch).cpu())
    return torch.cat(outputs)


def _save_atomically(obj, path):
    """torch.save *obj* to *path* via a temporary file and os.replace.

    An interrupted run therefore never leaves a truncated *path* behind that
    the resume check would later mistake for finished output.
    """
    tmp_path = f"{path}.tmp"
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


# Write one feature tensor per patient to `<outdir>/<patient_barcode>.pt`.
os.makedirs(outdir, exist_ok=True)
for patient_barcode, dfg in tqdm(df.groupby('patient_barcode'), total=df.patient_barcode.nunique()):
    outfile = f"{outdir}/{patient_barcode}.pt"
    # Resume support: skip patients whose features were already saved.
    if os.path.exists(outfile):
        continue
    features = _extract_features(_load_patient_images(dfg))
    _save_atomically(features, outfile)