|
import os
|
|
import torch.utils.data as data
|
|
import pre_proc
|
|
import cv2
|
|
from scipy.io import loadmat
|
|
import numpy as np
|
|
|
|
|
|
def rearrange_pts(pts):
    """Reorder each consecutive group of four points into a canonical corner order.

    ``pts`` is consumed four rows at a time. Within a group, the two points with
    the smallest x-coordinate form the left side and the remaining two the right
    side; each side is then split by y-coordinate (smaller y = "top"). The
    output order for every group is: top-left, top-right, bottom-left,
    bottom-right.

    Args:
        pts: 2-D array-like; column 0 is x and column 1 is y. The number of rows
            must be a multiple of 4. Any extra columns are carried through.

    Returns:
        ``np.float32`` array of the reordered rows, same shape as ``pts``.
        An empty ``(0, D)`` input yields an empty ``(0, D)`` array.

    Raises:
        ValueError: if the number of rows is not a multiple of 4.
    """
    pts = np.asarray(pts)
    if len(pts) % 4 != 0:
        # Without this check an incomplete trailing group fails later with an
        # unhelpful IndexError while indexing the sorted halves.
        raise ValueError(
            'rearrange_pts expects a multiple of 4 points, got {}'.format(len(pts)))

    boxes = []
    for k in range(0, len(pts), 4):
        pts_4 = pts[k:k + 4, :]
        # Split the quadrilateral into its left/right pairs by x-coordinate.
        # A stable sort makes tie-breaking deterministic (input order wins).
        x_inds = np.argsort(pts_4[:, 0], kind='stable')
        pt_l = pts_4[x_inds[:2], :]
        pt_r = pts_4[x_inds[2:], :]

        # Within each side the smaller y is the top corner.
        y_inds_l = np.argsort(pt_l[:, 1], kind='stable')
        y_inds_r = np.argsort(pt_r[:, 1], kind='stable')
        tl = pt_l[y_inds_l[0], :]
        bl = pt_l[y_inds_l[1], :]
        tr = pt_r[y_inds_r[0], :]
        br = pt_r[y_inds_r[1], :]

        boxes.extend((tl, tr, bl, br))

    if not boxes:
        # Preserve the trailing (column) dimensions for empty input instead of
        # collapsing to shape (0,).
        return np.zeros((0,) + pts.shape[1:], np.float32)
    return np.asarray(boxes, np.float32)
|
|
|
|
|
|
class BaseDataset(data.Dataset):
    """Dataset of images in ``<data_dir>/data/<phase>`` with ``.mat`` point labels.

    For non-test phases each image is paired with the annotation file
    ``<data_dir>/labels/<phase>/<img_id>.mat``, whose ``'p2'`` entry holds the
    ground-truth points (reordered by ``rearrange_pts``).

    Args:
        data_dir: dataset root directory.
        phase: split name. ``'test'`` yields images only; ``'train'`` loads
            annotations with augmentation enabled; any other value loads
            annotations without augmentation.
        input_h: network input height passed to ``pre_proc``. Must be set for
            non-test phases, where it is integer-divided by ``down_ratio``.
        input_w: network input width; same requirements as ``input_h``.
        down_ratio: factor between the input size and the ground-truth size.
    """

    def __init__(self, data_dir, phase, input_h=None, input_w=None, down_ratio=4):
        super().__init__()
        self.data_dir = data_dir
        self.phase = phase
        self.input_h = input_h
        self.input_w = input_w
        self.down_ratio = down_ratio
        self.class_name = ['__background__', 'cell']
        # NOTE(review): 68 does not match len(self.class_name). It is presumably
        # a landmark count used by downstream code (a multiple of 4, matching
        # the 4-point groups of rearrange_pts) — confirm before changing.
        self.num_classes = 68
        self.img_dir = os.path.join(data_dir, 'data', self.phase)
        # Sorted so the index -> image mapping is deterministic across runs.
        self.img_ids = sorted(os.listdir(self.img_dir))

    def load_image(self, index):
        """Read image ``index`` from ``img_dir`` with ``cv2.imread`` (default flags).

        Raises:
            OSError: if the image cannot be read. ``cv2.imread`` reports a
                missing or undecodable file by returning ``None`` instead of
                raising, which would otherwise surface later inside ``pre_proc``.
        """
        path = os.path.join(self.img_dir, self.img_ids[index])
        image = cv2.imread(path)
        if image is None:
            raise OSError('Failed to read image: {}'.format(path))
        return image

    def load_gt_pts(self, annopath):
        """Load the ``'p2'`` array from ``annopath`` and reorder it.

        Returns:
            float32 array from ``rearrange_pts``: each 4-point group in the order
            top-left, top-right, bottom-left, bottom-right.
        """
        pts = loadmat(annopath)['p2']
        return rearrange_pts(pts)

    def load_annoFolder(self, img_id):
        """Return the annotation file path for ``img_id``.

        Despite the name, this is a file path, not a folder:
        ``<data_dir>/labels/<phase>/<img_id>.mat``. The full image file name,
        including its extension, gets the ``.mat`` suffix appended.
        """
        return os.path.join(self.data_dir, 'labels', self.phase, img_id + '.mat')

    def load_annotation(self, index):
        """Return the reordered ground-truth points for image ``index``."""
        anno_path = self.load_annoFolder(self.img_ids[index])
        return self.load_gt_pts(anno_path)

    def __getitem__(self, index):
        """Return one sample.

        ``'test'`` phase: ``{'images': ..., 'img_id': ...}`` where ``images`` comes
        from ``pre_proc.processing_test``. Other phases: the dict returned by
        ``pre_proc.generate_ground_truth`` for the preprocessed image and points,
        with augmentation enabled only when ``phase == 'train'``.
        """
        img_id = self.img_ids[index]
        image = self.load_image(index)

        if self.phase == 'test':
            images = pre_proc.processing_test(image=image,
                                              input_h=self.input_h,
                                              input_w=self.input_w)
            return {'images': images, 'img_id': img_id}

        pts = self.load_annotation(index)
        out_image, pts_2 = pre_proc.processing_train(image=image,
                                                     pts=pts,
                                                     image_h=self.input_h,
                                                     image_w=self.input_w,
                                                     down_ratio=self.down_ratio,
                                                     aug_label=self.phase == 'train',
                                                     img_id=img_id)
        # Ground truth is generated at the input size divided by down_ratio.
        data_dict = pre_proc.generate_ground_truth(image=out_image,
                                                   pts_2=pts_2,
                                                   image_h=self.input_h // self.down_ratio,
                                                   image_w=self.input_w // self.down_ratio,
                                                   img_id=img_id)
        return data_dict

    def __len__(self):
        """Return the number of entries listed in ``img_dir``."""
        return len(self.img_ids)
|
|
|