from pathlib import Path
import scipy.io as sio
from PIL import Image
import numpy as np
from src.utils.config import cfg
from src.dataset.base_dataset import BaseDataset
import random
'''
Important Notice: Face image 160 contains only 8 labeled keypoints (should be 10)
'''
class WillowObject(BaseDataset):
    """
    Keypoint-matching dataset built from the per-class ``.mat`` annotation
    files found under ``cfg.WillowObject.ROOT_DIR``.

    For every class in ``cfg.WillowObject.CLASSES`` the annotation files are
    split into a training part of ``cfg.WillowObject.TRAIN_NUM`` files that
    starts at ``cfg.WillowObject.SPLIT_OFFSET`` (wrapping around the end of the
    file list) and a test part holding the remaining files.
    ``self.mat_list[c]`` stores the files of class ``c`` for the requested set.
    """

    def __init__(self, sets, obj_resize):
        """
        :param sets: 'train' or 'test'
        :param obj_resize: resized object size, indexed as (width, height)
        """
        super(WillowObject, self).__init__()
        self.classes = cfg.WillowObject.CLASSES
        self.kpt_len = [cfg.WillowObject.KPT_LEN for _ in cfg.WillowObject.CLASSES]

        self.root_path = Path(cfg.WillowObject.ROOT_DIR)
        self.obj_resize = obj_resize

        assert sets in ('train', 'test'), 'No match found for dataset {}'.format(sets)
        self.sets = sets
        self.split_offset = cfg.WillowObject.SPLIT_OFFSET
        self.train_len = cfg.WillowObject.TRAIN_NUM
        self.rand_outlier = cfg.WillowObject.RAND_OUTLIER

        # The training slice is only used for the 'train' set, and only when the
        # config does not ask training to reuse the test files.
        use_train_split = sets == 'train' and not cfg.WillowObject.TRAIN_SAME_AS_TEST

        self.mat_list = []
        for cls_name in self.classes:
            assert type(cls_name) is str
            # NOTE(review): glob() order is filesystem dependent, so the split
            # below may differ between machines -- confirm whether sorting is wanted.
            cls_mat_list = list((self.root_path / cls_name).glob('*.mat'))
            if cls_name == 'Face':
                # Face image 160 has only 8 labeled keypoints; leave it out.
                # Filtering (rather than list.remove) tolerates the file being absent.
                broken_anno = self.root_path / cls_name / 'image_0160.mat'
                cls_mat_list = [p for p in cls_mat_list if p != broken_anno]

            ori_len = len(cls_mat_list)
            assert ori_len > 0, \
                'No .mat annotation found in {}'.format(self.root_path / cls_name)

            start = self.split_offset % ori_len
            stop = start + self.train_len
            if stop <= ori_len:
                # Training window fits without wrapping. Slice with the un-wrapped
                # ``stop``: taking it modulo ``ori_len`` turns stop == ori_len into
                # 0, which would empty the training split and duplicate test files.
                if use_train_split:
                    selected = cls_mat_list[start:stop]
                else:
                    selected = cls_mat_list[:start] + cls_mat_list[stop:]
            else:
                # Training window runs past the end and wraps to the list head.
                wrap = stop % ori_len
                if use_train_split:
                    selected = cls_mat_list[:wrap] + cls_mat_list[start:]
                else:
                    selected = cls_mat_list[wrap:start]
            self.mat_list.append(selected)

    def __resolve_cls(self, cls):
        """
        Turn a class specifier into a class index.
        :param cls: None for a random class, a class name, or a class index
        :return: integer index into ``self.classes``
        """
        if cls is None:
            cls = random.randrange(0, len(self.classes))
        elif isinstance(cls, str):
            cls = self.classes.index(cls)
        assert isinstance(cls, int) and 0 <= cls < len(self.classes)
        return cls

    def __sample_annos(self, cls, num, shuffle):
        """
        Draw ``num`` distinct annotation files of class index ``cls`` and load them.
        :param cls: class index
        :param num: number of objects to be fetched
        :param shuffle: random shuffle the keypoints of every object
        :return: list of annotation dicts
        """
        annos = []
        for mat_name in random.sample(self.mat_list[cls], num):
            anno_dict = self.__get_anno_dict(mat_name, cls)
            if shuffle:
                random.shuffle(anno_dict['keypoints'])
            annos.append(anno_dict)
        return annos

    def get_pair(self, cls=None, shuffle=True):
        """
        Randomly get a pair of objects from WILLOW-object dataset
        :param cls: None for random class, or specify for a certain set
        :param shuffle: random shuffle the keypoints
        :return: (pair of data, groundtruth permutation matrix)
        """
        cls = self.__resolve_cls(cls)
        anno_pair = self.__sample_annos(cls, 2, shuffle)

        perm_mat = np.zeros([len(_['keypoints']) for _ in anno_pair], dtype=np.float32)
        row_list = []
        col_list = []
        for i, keypoint in enumerate(anno_pair[0]['keypoints']):
            for j, _keypoint in enumerate(anno_pair[1]['keypoints']):
                if keypoint['name'] == _keypoint['name']:
                    # Outliers are kept as nodes but never marked as a match.
                    # NOTE(review): every 'outlier' of the first object pairs with
                    # the first 'outlier' found in the second one, so col_list may
                    # hold repeated indices when rand_outlier > 1 -- confirm intent.
                    if keypoint['name'] != 'outlier':
                        perm_mat[i, j] = 1
                    row_list.append(i)
                    col_list.append(j)
                    break
        row_list.sort()
        col_list.sort()
        # Keep only the rows/columns (and keypoints) that found a counterpart.
        perm_mat = perm_mat[row_list, :]
        perm_mat = perm_mat[:, col_list]
        anno_pair[0]['keypoints'] = [anno_pair[0]['keypoints'][i] for i in row_list]
        anno_pair[1]['keypoints'] = [anno_pair[1]['keypoints'][j] for j in col_list]

        return anno_pair, perm_mat

    def get_multi(self, cls=None, num=2, shuffle=True):
        """
        Randomly get multiple objects from Willow Object Class dataset for multi-matching.
        :param cls: None for random class, or specify for a certain set
        :param num: number of objects to be fetched
        :param shuffle: random shuffle the keypoints
        :return: (list of data, list of permutation matrices)
        """
        cls = self.__resolve_cls(cls)
        anno_list = self.__sample_annos(cls, num, shuffle)

        # perm_mat[k] relates keypoints of object 0 (rows) to object k (columns)
        # while it is being filled; it is transposed before being returned.
        perm_mat = [np.zeros([len(anno_list[0]['keypoints']), len(x['keypoints'])], dtype=np.float32)
                    for x in anno_list]
        row_list = []
        col_lists = [[] for _ in range(num)]
        # Keypoint names do not change while matching, so build the lists once.
        kpt_name_lists = [[x['name'] for x in anno_dict['keypoints']] for anno_dict in anno_list]
        for i, keypoint in enumerate(anno_list[0]['keypoints']):
            # Index of this keypoint in every object, or -1 when it is absent.
            kpt_idx = [
                names.index(keypoint['name']) if keypoint['name'] in names else -1
                for names in kpt_name_lists
            ]
            row_list.append(i)
            for k in range(num):
                j = kpt_idx[k]
                if j != -1:
                    col_lists[k].append(j)
                    if keypoint['name'] != 'outlier':
                        perm_mat[k][i, j] = 1

        row_list.sort()
        for col_list in col_lists:
            col_list.sort()

        for k in range(num):
            perm_mat[k] = perm_mat[k][row_list, :]
            perm_mat[k] = perm_mat[k][:, col_lists[k]]
            anno_list[k]['keypoints'] = [anno_list[k]['keypoints'][j] for j in col_lists[k]]
            perm_mat[k] = perm_mat[k].transpose()

        return anno_list, perm_mat

    def __get_anno_dict(self, mat_file, cls):
        """
        Get an annotation dict from .mat annotation
        :param mat_file: path of the .mat annotation; the image is the .png file
                         with the same stem in the same directory
        :param cls: class index stored in the returned dict
        :return: dict with keys 'image', 'keypoints', 'bounds', 'ori_sizes',
                 'cls' and 'univ_size'
        """
        assert mat_file.exists(), '{} does not exist.'.format(mat_file)

        img_name = mat_file.stem + '.png'
        img_file = mat_file.parent / img_name

        # Close the annotation file once it is parsed instead of leaking the handle.
        with mat_file.open('rb') as mat_fp:
            struct = sio.loadmat(mat_fp)
        kpts = struct['pts_coord']

        with Image.open(str(img_file)) as img:
            ori_sizes = img.size
            obj = img.resize(self.obj_resize, resample=Image.BICUBIC)
            xmin = 0
            ymin = 0
            w = ori_sizes[0]
            h = ori_sizes[1]

        keypoint_list = []
        # One keypoint per column; row 0 is read as x and row 1 as y. Index the
        # scalars directly instead of calling float() on 1-element arrays, which
        # newer NumPy releases deprecate.
        for idx in range(kpts.shape[1]):
            attr = {
                'name': idx,
                'x': float(kpts[0, idx]) * self.obj_resize[0] / w,
                'y': float(kpts[1, idx]) * self.obj_resize[1] / h
            }
            keypoint_list.append(attr)

        # Append uniformly random outlier points inside the resized image.
        for idx in range(self.rand_outlier):
            attr = {
                'name': 'outlier',
                'x': random.uniform(0, self.obj_resize[0]),
                'y': random.uniform(0, self.obj_resize[1])
            }
            keypoint_list.append(attr)

        anno_dict = dict()
        anno_dict['image'] = obj
        anno_dict['keypoints'] = keypoint_list
        anno_dict['bounds'] = xmin, ymin, w, h
        anno_dict['ori_sizes'] = ori_sizes
        anno_dict['cls'] = cls
        # NOTE(review): hard-coded; presumably the number of labeled keypoints
        # per image -- confirm against cfg.WillowObject.KPT_LEN.
        anno_dict['univ_size'] = 10

        return anno_dict

    def len(self, cls):
        """
        Number of annotation files held for a class in the current set.
        :param cls: class index or class name
        :return: number of files in ``self.mat_list`` for that class
        """
        if isinstance(cls, int):
            cls = self.classes[cls]
        assert cls in self.classes
        return len(self.mat_list[self.classes.index(cls)])
if __name__ == '__main__':
    # Smoke check: with a zero split offset, no annotation file of any class
    # may appear in both the training and the test split.
    cfg.WillowObject.ROOT_DIR = 'WILLOW-ObjectClass'
    cfg.WillowObject.SPLIT_OFFSET = 0

    train = WillowObject('train', (256, 256))
    test = WillowObject('test', (256, 256))

    for train_files, test_files in zip(train.mat_list, test.mat_list):
        assert all(mat not in test_files for mat in train_files)