|
import torch |
|
import torchvision.transforms as tvt |
|
import pandas as pd |
|
import os |
|
from tqdm import tqdm |
|
from PIL import Image |
|
|
|
# Input table (read with pandas.read_csv further down) and the directory that
# receives one feature file per patient.
yolo_crop_file = "image_yolo.txt"
outdir = "pt_files/train"

# Cap PyTorch's CPU thread pool at two threads.
torch.set_num_threads(2)
|
|
|
def crop(img, x, y, w, h):
    """Crop a centre/size box, given in relative units, out of ``img``.

    ``x`` and ``y`` are the box centre and ``w`` and ``h`` its width and
    height; each is scaled by the image's own width or height (taken from
    ``img.size``) to obtain pixel coordinates.

    Returns whatever ``img.crop`` returns for the resulting
    ``(x1, y1, x2, y2)`` box.
    """
    img_w, img_h = img.size

    center_x, half_w = x * img_w, w * img_w / 2.0
    center_y, half_h = y * img_h, h * img_h / 2.0

    box = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return img.crop(box)
|
|
|
def is_report_file(s):
    """Return True when the path string ``s`` contains the substring 'RPT'."""
    return 'RPT' in s


def get_barcode(s):
    """Return the third-from-last '/'-separated component of ``s``.

    Raises IndexError when ``s`` has fewer than three components.
    """
    return s.split('/')[-3]
|
|
|
# File handed to torch.jit.load when the network is created.
model_path = "./traced_swav_imagenet_layer2.pt"

# Every preprocessed image is a CHANNEL x IMAGE_SIZE x IMAGE_SIZE tensor.
CHANNEL = 3
IMAGE_SIZE = 448

# Per-channel mean / standard deviation used for normalization.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

normalize = tvt.Normalize(
    mean=torch.tensor(IMAGENET_DEFAULT_MEAN),
    std=torch.tensor(IMAGENET_DEFAULT_STD),
)

# Preprocessing pipeline: resize, centre-crop to a square, convert to a
# tensor, then normalize.
transform_ops = tvt.Compose(
    [
        tvt.Resize(IMAGE_SIZE),
        tvt.CenterCrop(IMAGE_SIZE),
        tvt.ToTensor(),
        normalize,
    ]
)
|
|
|
# Read the crop table, then prepend two columns derived from each row's
# ``orig`` path: the report-file flag and the patient barcode.
df = pd.read_csv(yolo_crop_file)
df.insert(0, "is_report_file", list(map(is_report_file, df["orig"])))
df.insert(0, "patient_barcode", list(map(get_barcode, df["orig"])))

# Keep only the rows that are not flagged as report files.
df = df.loc[df["is_report_file"].eq(False)]
|
|
|
# Load the TorchScript module, move it to the GPU and put it in eval mode.
net = torch.jit.load(model_path).cuda()
net.eval()
|
|
|
|
|
def _load_patient_images(group):
    """Build the input batch for one patient's rows.

    Each row of ``group`` supplies an image path (``orig``) and a crop box
    (``x``, ``y``, ``w``, ``h``).  The image is opened, converted to RGB,
    cropped with ``crop`` and passed through ``transform_ops``.

    Returns a ``(len(group), CHANNEL, IMAGE_SIZE, IMAGE_SIZE)`` tensor.
    """
    batch = torch.zeros(len(group), CHANNEL, IMAGE_SIZE, IMAGE_SIZE)
    rows = zip(group.orig, group.x, group.y, group.w, group.h)
    for i, (image_file, x, y, w, h) in enumerate(rows):
        with open(image_file, 'rb') as f:
            # Image.open is lazy; convert() reads the pixel data while the
            # file handle is still open.
            img = Image.open(f).convert('RGB')
        batch[i] = transform_ops(crop(img, x, y, w, h))
    return batch


# Create the output directory up front so a missing path cannot make
# torch.save fail after the features have already been computed.
os.makedirs(outdir, exist_ok=True)

patient_groups = df.groupby('patient_barcode')
for patient_barcode, dfg in tqdm(patient_groups, total=patient_groups.ngroups):
    outfile = f"{outdir}/{patient_barcode}.pt"
    # Resume support: skip patients whose features already exist.
    if os.path.exists(outfile):
        continue

    image_tensors = _load_patient_images(dfg).cuda()
    with torch.no_grad():
        features = net(image_tensors).cpu()

    # Save under a temporary name and rename afterwards, so an interrupted
    # run never leaves a truncated file that the exists() check above would
    # then skip on the next run.
    tmpfile = f"{outfile}.tmp"
    torch.save(features, tmpfile)
    os.replace(tmpfile, outfile)
|
|
|
|