# Source: uploaded by Santipab ("Upload 30 files", commit a19d827, verified).
import os
import torch.utils.data as data
import pre_proc
import cv2
from scipy.io import loadmat
import numpy as np
def rearrange_pts(pts):
    """Reorder every group of four (x, y) points as tl, tr, bl, br.

    ``pts`` is read as consecutive groups of four rows. Within a group, the
    two points with the smallest x form the left pair and the other two the
    right pair. Each pair is then split by y: the smaller y is "top" and the
    larger y is "bottom".

    Args:
        pts: array-like of shape (N, 2) with (x, y) rows; N must be a
            multiple of 4.

    Returns:
        np.float32 array of shape (N, 2). Each group of four is ordered
        top-left, top-right, bottom-left, bottom-right.

    Raises:
        ValueError: if N is not a multiple of 4. Previously an incomplete
            trailing group failed with an unhelpful IndexError.
    """
    pts = np.asarray(pts)
    if len(pts) % 4 != 0:
        raise ValueError(
            'expected a multiple of 4 points, got {}'.format(len(pts)))

    boxes = []
    for k in range(0, len(pts), 4):
        pts_4 = pts[k:k + 4, :]
        # Split the group into a left pair and a right pair by x.
        x_inds = np.argsort(pts_4[:, 0])
        pt_l = pts_4[x_inds[:2], :]
        pt_r = pts_4[x_inds[2:], :]
        # Within each pair, the smaller y is the top corner.
        tl, bl = pt_l[np.argsort(pt_l[:, 1])]
        tr, br = pt_r[np.argsort(pt_r[:, 1])]
        boxes.extend((tl, tr, bl, br))

    if not boxes:
        # Keep the result 2-D so callers always get an (N, 2)-style array.
        n_cols = pts.shape[1] if pts.ndim == 2 else 2
        return np.zeros((0, n_cols), np.float32)
    return np.asarray(boxes, np.float32)
class BaseDataset(data.Dataset):
    """Dataset reading images from ``<data_dir>/data/<phase>``.

    Image ids are the sorted file names in that directory, including their
    extension. For every phase other than ``'test'``, point annotations are
    read from ``<data_dir>/labels/<phase>/<image id>.mat`` (key ``'p2'``).

    Args:
        data_dir: dataset root directory.
        phase: split name. ``'test'`` skips annotations. ``'train'`` passes
            ``aug_label=True`` to ``pre_proc.processing_train``; any other
            phase passes ``aug_label=False``.
        input_h, input_w: input size forwarded to ``pre_proc``. For
            non-test phases they must be ints, because ``input // down_ratio``
            is computed for ``pre_proc.generate_ground_truth``.
        down_ratio: divisor applied to ``input_h``/``input_w`` for the
            ground-truth size.
    """

    def __init__(self, data_dir, phase, input_h=None, input_w=None, down_ratio=4):
        super().__init__()
        self.data_dir = data_dir
        self.phase = phase
        self.input_h = input_h
        self.input_w = input_w
        self.down_ratio = down_ratio
        self.class_name = ['__background__', 'cell']
        # NOTE(review): 68 does not equal len(self.class_name); it presumably
        # counts something other than object classes (e.g. output channels).
        # Confirm against the model/decoder before relying on it.
        self.num_classes = 68
        self.img_dir = os.path.join(data_dir, 'data', self.phase)
        self.img_ids = sorted(os.listdir(self.img_dir))

    def load_image(self, index):
        """Read image number ``index`` with OpenCV.

        Raises:
            OSError: if OpenCV cannot read the file. ``cv2.imread`` signals
                failure by returning None rather than raising, which would
                otherwise surface later as an obscure error.
        """
        path = os.path.join(self.img_dir, self.img_ids[index])
        image = cv2.imread(path)
        if image is None:
            raise OSError('could not read image: {}'.format(path))
        return image

    def load_gt_pts(self, annopath):
        """Load the ``'p2'`` point array from a .mat file and reorder it.

        Each group of four points is reordered by ``rearrange_pts``.
        """
        pts = loadmat(annopath)['p2']  # num x 2 (x,y)
        pts = rearrange_pts(pts)
        return pts

    def load_annoFolder(self, img_id):
        """Return the annotation file path for ``img_id``.

        The path is ``<data_dir>/labels/<phase>/<img_id>.mat``. Despite the
        method name, this is a file path, not a folder.
        """
        return os.path.join(self.data_dir, 'labels', self.phase, img_id + '.mat')

    def load_annotation(self, index):
        """Return the reordered ground-truth points for sample ``index``."""
        img_id = self.img_ids[index]
        annoFolder = self.load_annoFolder(img_id)
        pts = self.load_gt_pts(annoFolder)
        return pts

    def __getitem__(self, index):
        """Return one sample as a dict.

        In the ``'test'`` phase the result is ``{'images': ..., 'img_id': ...}``,
        with ``images`` produced by ``pre_proc.processing_test``.

        Otherwise the annotation is loaded and passed, together with the
        image, through ``pre_proc.processing_train``. The dict built by
        ``pre_proc.generate_ground_truth`` is then returned unchanged.
        """
        img_id = self.img_ids[index]
        image = self.load_image(index)
        if self.phase == 'test':
            images = pre_proc.processing_test(image=image,
                                              input_h=self.input_h,
                                              input_w=self.input_w)
            return {'images': images, 'img_id': img_id}

        aug_label = self.phase == 'train'
        # (num_pts, 2) point array, reordered per 4-point group.
        pts = self.load_annotation(index)
        out_image, pts_2 = pre_proc.processing_train(image=image,
                                                     pts=pts,
                                                     image_h=self.input_h,
                                                     image_w=self.input_w,
                                                     down_ratio=self.down_ratio,
                                                     aug_label=aug_label,
                                                     img_id=img_id)
        data_dict = pre_proc.generate_ground_truth(image=out_image,
                                                   pts_2=pts_2,
                                                   image_h=self.input_h // self.down_ratio,
                                                   image_w=self.input_w // self.down_ratio,
                                                   img_id=img_id)
        return data_dict

    def __len__(self):
        """Number of image files found in the phase's image directory."""
        return len(self.img_ids)