Source code for src.dataset.imc_pt_sparsegm

from pathlib import Path
from PIL import Image
import numpy as np
from src.utils.config import cfg
from src.dataset.base_dataset import BaseDataset
import random


class IMC_PT_SparseGM(BaseDataset):
    def __init__(self, sets, obj_resize):
        """
        :param sets: 'train' or 'test'
        :param obj_resize: resized object size
        """
        super(IMC_PT_SparseGM, self).__init__()
        assert sets in ('train', 'test'), 'No match found for dataset {}'.format(sets)
        self.sets = sets
        self.classes = cfg.IMC_PT_SparseGM.CLASSES[sets]
        self.total_kpt_num = cfg.IMC_PT_SparseGM.TOTAL_KPT_NUM
        self.root_path_npz = Path(cfg.IMC_PT_SparseGM.ROOT_DIR_NPZ)
        self.root_path_img = Path(cfg.IMC_PT_SparseGM.ROOT_DIR_IMG)
        self.obj_resize = obj_resize
        self.img_lists = [np.load(self.root_path_npz / cls / 'img_info.npz')['img_name'].tolist()
                          for cls in self.classes]
    def get_pair(self, cls=None, shuffle=True, tgt_outlier=False, src_outlier=False):
        """
        Randomly get a pair of objects from the Photo Tourism dataset.
        :param cls: None for random class, or specify for a certain set
        :param shuffle: random shuffle the keypoints
        :param src_outlier: allow outliers in the source graph (first graph)
        :param tgt_outlier: allow outliers in the target graph (second graph)
        :return: (pair of data, groundtruth permutation matrix)
        """
        if cls is None:
            cls = random.randrange(0, len(self.classes))
        elif type(cls) == str:
            cls = self.classes.index(cls)
        assert type(cls) == int and 0 <= cls < len(self.classes)

        anno_pair = []
        for img_name in random.sample(self.img_lists[cls], 2):
            anno_dict = self.__get_anno_dict(img_name, cls)
            if shuffle:
                random.shuffle(anno_dict['keypoints'])
            anno_pair.append(anno_dict)

        # build the ground-truth permutation matrix from matching keypoint names
        perm_mat = np.zeros([len(_['keypoints']) for _ in anno_pair], dtype=np.float32)
        row_list = []
        col_list = []
        for i, keypoint in enumerate(anno_pair[0]['keypoints']):
            for j, _keypoint in enumerate(anno_pair[1]['keypoints']):
                if keypoint['name'] == _keypoint['name']:
                    if keypoint['name'] != 'outlier':
                        perm_mat[i, j] = 1
                        row_list.append(i)
                        col_list.append(j)
                    break
        row_list.sort()
        col_list.sort()
        # unless outliers are allowed, keep only keypoints that are matched in both images
        if not src_outlier:
            perm_mat = perm_mat[row_list, :]
            anno_pair[0]['keypoints'] = [anno_pair[0]['keypoints'][i] for i in row_list]
        if not tgt_outlier:
            perm_mat = perm_mat[:, col_list]
            anno_pair[1]['keypoints'] = [anno_pair[1]['keypoints'][j] for j in col_list]

        return anno_pair, perm_mat
    def get_multi(self, cls=None, num=2, shuffle=True, filter_outlier=False):
        """
        Randomly get multiple objects from the Photo Tourism dataset for multi-matching.
        :param cls: None for random class, or specify for a certain set
        :param num: number of objects to be fetched
        :param shuffle: random shuffle the keypoints
        :param filter_outlier: filter out outlier keypoints among images
        :return: (list of data, list of permutation matrices)
        """
        assert not filter_outlier, \
            'Multi-matching on IMC_PT_SparseGM dataset with filtered outliers is not supported'
        if cls is None:
            cls = random.randrange(0, len(self.classes))
        elif type(cls) == str:
            cls = self.classes.index(cls)
        assert type(cls) == int and 0 <= cls < len(self.classes)

        anno_list = []
        for img_name in random.sample(self.img_lists[cls], num):
            anno_dict = self.__get_anno_dict(img_name, cls)
            if shuffle:
                random.shuffle(anno_dict['keypoints'])
            anno_list.append(anno_dict)

        # each matrix maps the universe of total_kpt_num keypoints to the keypoints
        # visible in one image; keypoint names are indices into that universe
        perm_mat = [np.zeros([self.total_kpt_num, len(x['keypoints'])], dtype=np.float32) for x in anno_list]
        for k, anno_dict in enumerate(anno_list):
            kpt_name_list = [x['name'] for x in anno_dict['keypoints']]
            for j, name in enumerate(kpt_name_list):
                perm_mat[k][name, j] = 1
        for k in range(num):
            perm_mat[k] = perm_mat[k].transpose()

        return anno_list, perm_mat
    def __get_anno_dict(self, img_name, cls):
        """
        Get an annotation dict from the .npz annotation file.
        """
        img_file = self.root_path_img / self.classes[cls] / 'dense' / 'images' / img_name
        npz_file = self.root_path_npz / self.classes[cls] / (img_name + '.npz')

        assert img_file.exists(), '{} does not exist.'.format(img_file)
        assert npz_file.exists(), '{} does not exist.'.format(npz_file)

        with Image.open(str(img_file)) as img:
            ori_sizes = img.size
            obj = img.resize(self.obj_resize, resample=Image.BICUBIC)
            xmin = 0
            ymin = 0
            w = ori_sizes[0]
            h = ori_sizes[1]

        with np.load(str(npz_file)) as npz_anno:
            kpts = npz_anno['points']
            if len(kpts.shape) != 2:
                raise ValueError('{} contains no keypoints.'.format(img_file))

        # rows of the keypoint array used here: universe index, x, y (original pixel coordinates)
        keypoint_list = []
        for i in range(kpts.shape[1]):
            kpt_index = int(kpts[0, i])
            assert kpt_index < self.total_kpt_num
            attr = {
                'name': kpt_index,
                'x': kpts[1, i] * self.obj_resize[0] / w,
                'y': kpts[2, i] * self.obj_resize[1] / h
            }
            keypoint_list.append(attr)

        anno_dict = dict()
        anno_dict['image'] = obj
        anno_dict['keypoints'] = keypoint_list
        anno_dict['bounds'] = xmin, ymin, w, h
        anno_dict['ori_sizes'] = ori_sizes
        anno_dict['cls'] = self.classes[cls]
        anno_dict['univ_size'] = self.total_kpt_num

        return anno_dict
    def len(self, cls):
        if type(cls) == int:
            cls = self.classes[cls]
        assert cls in self.classes
        return len(self.img_lists[self.classes.index(cls)])
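
A minimal usage sketch, not part of the original module: it assumes cfg.IMC_PT_SparseGM already points at valid ROOT_DIR_NPZ / ROOT_DIR_IMG data on disk, and the (256, 256) object size is an arbitrary choice. The shape checks only restate how get_pair and get_multi construct their permutation matrices above.

if __name__ == '__main__':
    # sketch only: requires the IMC-PT-SparseGM images and .npz annotations
    # to be available at the paths configured in cfg.IMC_PT_SparseGM
    dataset = IMC_PT_SparseGM('train', (256, 256))

    # pair matching: perm_mat[i, j] == 1 iff keypoint i of the first image
    # corresponds to keypoint j of the second image
    anno_pair, perm_mat = dataset.get_pair(cls=None, shuffle=True)
    assert perm_mat.shape == (len(anno_pair[0]['keypoints']),
                              len(anno_pair[1]['keypoints']))

    # multi-matching: each matrix maps an image's keypoints into the universe of
    # total_kpt_num keypoints, i.e. shape (num keypoints in image k, total_kpt_num)
    anno_list, perm_mat_list = dataset.get_multi(cls=None, num=3)
    for anno_dict, mat in zip(anno_list, perm_mat_list):
        assert mat.shape == (len(anno_dict['keypoints']), dataset.total_kpt_num)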