import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
import torch_geometric as pyg
import numpy as np
import random
from src.build_graphs import build_graphs
from src.factorize_graph_matching import kronecker_sparse, kronecker_torch
from src.sparse_torch import CSRMatrix3d
from src.dataset import *
from src.utils.config import cfg
from itertools import combinations, product
class GMDataset(Dataset):
    """Graph-matching dataset that converts benchmark annotations into model inputs.

    Depending on ``problem`` an item is a pair of graphs (``'2GM'``), a list of
    graphs (``'MGM'``) or several such lists concatenated (``'MGM3'``).

    :param name: dataset name, stored unchanged
    :param bm: benchmark object; this class reads ``obj_resize``, ``sets`` and
        ``classes`` from it and calls ``compute_img_num``, ``compute_length``,
        ``get_id_combination``, ``get_data`` and ``rand_get_data``
    :param length: value reported by ``len()`` for non-2GM problems; for
        ``'2GM'`` the length returned by ``bm.get_id_combination`` is used
    :param using_all_graphs: for MGM/MGM3, take the number of graphs per sample
        from ``bm.compute_img_num`` instead of ``cfg.PROBLEM.NUM_GRAPHS``
    :param cls: class restriction; ``None``, ``'none'`` and ``'all'`` all mean
        "no restriction"
    :param problem: one of ``'2GM'``, ``'MGM'``, ``'MGM3'``
    """

    def __init__(self, name, bm, length, using_all_graphs=False, cls=None, problem='2GM'):
        self.name = name
        self.bm = bm
        self.using_all_graphs = using_all_graphs
        self.obj_size = self.bm.obj_resize
        self.test = self.bm.sets == 'test'
        self.cls = None if cls in ['none', 'all'] else cls

        if self.cls is None:
            if problem == 'MGM3':
                # every combination of NUM_CLUSTERS classes is a candidate sample
                self.classes = list(combinations(self.bm.classes, cfg.PROBLEM.NUM_CLUSTERS))
            else:
                self.classes = self.bm.classes
        else:
            self.classes = [self.cls]

        self.problem_type = problem
        if problem != 'MGM3':
            self.img_num_list = self.bm.compute_img_num(self.classes)
        else:
            # NOTE(review): only the first entry of self.classes is counted here,
            # so img_num_list is indexed by position inside self.classes[0].
            self.img_num_list = self.bm.compute_img_num(self.classes[0])

        if self.problem_type == '2GM':
            self.id_combination, self.length = self.bm.get_id_combination(self.cls)
            self.length_list = [self.bm.compute_length(c) for c in self.classes]
        else:
            self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Dispatch to the sampler that matches ``self.problem_type``.

        :raises NameError: if the problem type is not '2GM', 'MGM' or 'MGM3'
        """
        if self.problem_type == '2GM':
            return self.get_pair(idx, self.cls)
        elif self.problem_type == 'MGM':
            return self.get_multi(idx, self.cls)
        elif self.problem_type == 'MGM3':
            return self.get_multi_cluster(idx)
        else:
            raise NameError("Unknown problem type: {}".format(self.problem_type))

    @staticmethod
    def to_pyg_graph(A, P):
        """Build a ``torch_geometric`` ``Data`` object from an adjacency and keypoints.

        :param A: adjacency matrix; its non-zero entries define the edges
        :param P: keypoint coordinates (presumably ``(num_nodes, 2)`` - confirm
            against ``build_graphs``)
        :return: ``pyg.data.Data`` with node features ``P / rescale``, edge
            attributes derived from pairwise coordinate differences, and a
            third-order ``hyperedge_index``
        """
        rescale = max(cfg.PROBLEM.RESCALE)

        # pairwise coordinate differences mapped around 0.5  (from Rolink's paper)
        edge_feat = 0.5 * (np.expand_dims(P, axis=1) - np.expand_dims(P, axis=0)) / rescale + 0.5
        edge_index = np.nonzero(A)
        edge_attr = edge_feat[edge_index]

        edge_attr = np.clip(edge_attr, 0, 1)
        # after clipping this can only fail on NaN values
        assert (edge_attr > -1e-5).all(), P

        # entry (i, j, k) is non-zero when the three pairwise products of A are non-zero
        o3_A = np.expand_dims(A, axis=0) * np.expand_dims(A, axis=1) * np.expand_dims(A, axis=2)
        hyperedge_index = np.nonzero(o3_A)

        pyg_graph = pyg.data.Data(
            x=torch.tensor(P / rescale).to(torch.float32),
            edge_index=torch.tensor(np.array(edge_index), dtype=torch.long),
            edge_attr=torch.tensor(edge_attr).to(torch.float32),
            hyperedge_index=torch.tensor(np.array(hyperedge_index), dtype=torch.long),
        )
        return pyg_graph

    @staticmethod
    def _add_visual_inputs(ret_dict, anno_list):
        """Attach normalized images, or stacked keypoint features, to ``ret_dict`` in place.

        Images are used when the first annotation carries one; otherwise
        per-keypoint ``'feat'`` entries are stacked along the last axis when the
        first keypoint has them. If neither is present nothing is added.
        """
        imgs = [anno['img'] for anno in anno_list]
        if imgs[0] is not None:
            trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(cfg.NORM_MEANS, cfg.NORM_STD)
            ])
            ret_dict['images'] = [trans(img) for img in imgs]
        elif 'feat' in anno_list[0]['kpts'][0]:
            feats = [np.stack([kp['feat'] for kp in anno['kpts']], axis=-1) for anno in anno_list]
            ret_dict['features'] = [torch.Tensor(x) for x in feats]

    def get_pair(self, idx, cls):
        """Return one two-graph matching sample.

        A class index is drawn at random and ``idx`` (modulo that class's length)
        selects the id pair. Pairs whose permutation matrix has a side of at
        most 2, or whose size reaches a positive ``cfg.PROBLEM.MAX_PROB_SIZE``,
        are replaced by random draws from ``bm.rand_get_data(cls)``.

        :param idx: item index
        :param cls: class passed to ``bm.rand_get_data`` when resampling
        :return: dict with keys ``Ps, ns, es, gt_perm_mat, Gs, Hs, As,
            pyg_graphs, cls, id_list, univ_size`` plus ``images`` or ``features``
            when available
        """
        cls_num = random.randrange(0, len(self.classes))
        ids = list(self.id_combination[cls_num][idx % self.length_list[cls_num]])
        anno_pair, perm_mat_, id_list = self.bm.get_data(ids)
        perm_mat = perm_mat_[(0, 1)].toarray()
        while min(perm_mat.shape[0], perm_mat.shape[1]) <= 2 or perm_mat.size >= cfg.PROBLEM.MAX_PROB_SIZE > 0:
            anno_pair, perm_mat_, id_list = self.bm.rand_get_data(cls)
            perm_mat = perm_mat_[(0, 1)].toarray()

        cls = [anno['cls'] for anno in anno_pair]
        P1 = [(kp['x'], kp['y']) for kp in anno_pair[0]['kpts']]
        P2 = [(kp['x'], kp['y']) for kp in anno_pair[1]['kpts']]

        n1, n2 = len(P1), len(P2)
        univ_size = [anno['univ_size'] for anno in anno_pair]

        P1 = np.array(P1)
        P2 = np.array(P2)

        A1, G1, H1, e1 = build_graphs(P1, n1, stg=cfg.GRAPH.SRC_GRAPH_CONSTRUCT, sym=cfg.GRAPH.SYM_ADJACENCY)
        if cfg.GRAPH.TGT_GRAPH_CONSTRUCT == 'same':
            # reuse the source topology, permuted by the ground-truth matching
            G2 = perm_mat.transpose().dot(G1)
            H2 = perm_mat.transpose().dot(H1)
            A2 = G2.dot(H2.transpose())
            e2 = e1
        else:
            A2, G2, H2, e2 = build_graphs(P2, n2, stg=cfg.GRAPH.TGT_GRAPH_CONSTRUCT, sym=cfg.GRAPH.SYM_ADJACENCY)

        pyg_graph1 = self.to_pyg_graph(A1, P1)
        pyg_graph2 = self.to_pyg_graph(A2, P2)

        ret_dict = {'Ps': [torch.Tensor(x) for x in [P1, P2]],
                    'ns': [torch.tensor(x) for x in [n1, n2]],
                    'es': [torch.tensor(x) for x in [e1, e2]],
                    'gt_perm_mat': perm_mat,
                    'Gs': [torch.Tensor(x) for x in [G1, G2]],
                    'Hs': [torch.Tensor(x) for x in [H1, H2]],
                    'As': [torch.Tensor(x) for x in [A1, A2]],
                    'pyg_graphs': [pyg_graph1, pyg_graph2],
                    'cls': [str(x) for x in cls],
                    'id_list': id_list,
                    'univ_size': [torch.tensor(int(x)) for x in univ_size],
                    }

        self._add_visual_inputs(ret_dict, anno_pair)
        return ret_dict

    def get_multi(self, idx, cls):
        """Return one multi-graph matching sample.

        :param idx: item index (not used for sampling; graphs are drawn at random)
        :param cls: ``None`` (pick a random class), a class name, or an integer
            index into the class list
        :return: dict with keys ``Ps, ns, gt_perm_mat, Gs, Hs, As, Gs_tgt,
            Hs_tgt, As_tgt, pyg_graphs, pyg_graphs_tgt, cls, id_list, univ_size``
            plus ``images`` or ``features`` when available
        """
        if self.using_all_graphs and self.problem_type in ('MGM', 'MGM3'):
            # img_num_list is aligned with self.classes for MGM and with
            # self.classes[0] for MGM3 (see __init__)
            class_names = self.classes if self.problem_type == 'MGM' else self.classes[0]
            if cls is None:
                cls = random.randrange(0, len(class_names))
                num_graphs = self.img_num_list[cls]
                # NOTE(review): for MGM3 this indexes self.classes (the list of
                # combinations), not class_names - confirm this is intended.
                cls = self.classes[cls]
            elif isinstance(cls, str):
                # NOTE(review): for MGM3, .index() raises ValueError when cls is
                # not part of the first combination - verify against callers.
                num_graphs = self.img_num_list[class_names.index(cls)]
            else:
                num_graphs = self.img_num_list[cls]
                cls = self.classes[cls]
        else:
            num_graphs = cfg.PROBLEM.NUM_GRAPHS

        # redraw until every pairwise problem is neither too small nor too large
        refetch = True
        while refetch:
            refetch = False
            anno_list, perm_mat_dict, id_list = self.bm.rand_get_data(cls, num=num_graphs)
            perm_mat_dict = {key: val.toarray() for key, val in perm_mat_dict.items()}
            for pm in perm_mat_dict.values():
                if pm.shape[0] <= 2 or pm.shape[1] <= 2 or pm.size >= cfg.PROBLEM.MAX_PROB_SIZE > 0:
                    refetch = True
                    break

        cls = [anno['cls'] for anno in anno_list]
        Ps = [[(kp['x'], kp['y']) for kp in anno_dict['kpts']] for anno_dict in anno_list]

        ns = [len(P) for P in Ps]
        univ_size = [anno['univ_size'] for anno in anno_list]

        Ps = [np.array(P) for P in Ps]

        As = []
        Gs = []
        Hs = []
        As_tgt = []
        Gs_tgt = []
        Hs_tgt = []
        for P, n in zip(Ps, ns):
            # In multi-graph matching (MGM), when a graph is regarded as target graph, its topology may be different
            # from when it is regarded as source graph. These are represented by suffix "tgt".
            A, G, H, _ = build_graphs(P, n, stg=cfg.GRAPH.SRC_GRAPH_CONSTRUCT)
            A_tgt, G_tgt, H_tgt, _ = build_graphs(P, n, stg=cfg.GRAPH.TGT_GRAPH_CONSTRUCT)
            As.append(A)
            Gs.append(G)
            Hs.append(H)
            As_tgt.append(A_tgt)
            Gs_tgt.append(G_tgt)
            Hs_tgt.append(H_tgt)

        pyg_graphs = [self.to_pyg_graph(A, P) for A, P in zip(As, Ps)]
        pyg_graphs_tgt = [self.to_pyg_graph(A, P) for A, P in zip(As_tgt, Ps)]

        ret_dict = {
            'Ps': [torch.Tensor(x) for x in Ps],
            'ns': [torch.tensor(x) for x in ns],
            'gt_perm_mat': perm_mat_dict,
            'Gs': [torch.Tensor(x) for x in Gs],
            'Hs': [torch.Tensor(x) for x in Hs],
            'As': [torch.Tensor(x) for x in As],
            'Gs_tgt': [torch.Tensor(x) for x in Gs_tgt],
            'Hs_tgt': [torch.Tensor(x) for x in Hs_tgt],
            'As_tgt': [torch.Tensor(x) for x in As_tgt],
            'pyg_graphs': pyg_graphs,
            'pyg_graphs_tgt': pyg_graphs_tgt,
            'cls': [str(x) for x in cls],
            'id_list': id_list,
            'univ_size': [torch.tensor(int(x)) for x in univ_size],
        }

        self._add_visual_inputs(ret_dict, anno_list)
        return ret_dict

    def get_multi_cluster(self, idx):
        """Return one MGM3 sample: ``get_multi`` results of several classes, concatenated.

        The classes are a random entry of ``self.classes`` when no class
        restriction is set, otherwise the entries of ``self.cls``.

        :param idx: item index, forwarded to ``get_multi``
        :return: dict with the same keys as ``get_multi``; each value is a list
            extended with the per-class values in order
        """
        if self.cls is None or self.cls == 'none':
            cls_iterator = random.choice(self.classes)
        else:
            cls_iterator = self.cls
        dicts = [self.get_multi(idx, cls) for cls in cls_iterator]

        ret_dict = {}
        for key in dicts[0]:
            merged = []
            for dic in dicts:
                # NOTE(review): extending with a dict value (e.g. 'gt_perm_mat')
                # adds only its keys - confirm consumers expect that.
                merged.extend(dic[key])
            ret_dict[key] = merged
        return ret_dict
class QAPDataset(Dataset):
    """Dataset wrapper for QAP instances provided by a named dataset class.

    :param name: name of the underlying dataset class; it is resolved with
        ``eval`` in this module's namespace
    :param length: value reported by ``len()``
    :param pad: accepted for interface compatibility; not used by this class
    :param cls: class restriction forwarded to the underlying dataset; the
        string ``'none'`` is stored as ``None`` on this wrapper
    :param args: extra keyword arguments forwarded to the underlying dataset
    """

    def __init__(self, name, length, pad=16, cls=None, **args):
        self.name = name
        # SECURITY: eval() executes arbitrary expressions; `name` must come from
        # trusted configuration only, never from untrusted input.
        self.ds = eval(self.name)(**args, cls=cls)
        self.classes = self.ds.classes
        self.cls = None if cls == 'none' else cls
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return one QAP instance, resampling a random index while it is unusable.

        An instance is rejected when its permutation matrix has at most 4
        entries or when its size reaches a positive ``cfg.PROBLEM.MAX_PROB_SIZE``.

        :param idx: item index, wrapped modulo the number of underlying instances
        :return: dict with keys ``Fi, Fj, gt_perm_mat, ns, solution, name,
            univ_size``
        """
        # Iterative retry: a recursive retry could exhaust the interpreter's
        # recursion limit when many instances in a row are rejected.
        while True:
            Fi, Fj, perm_mat, sol, name = self.ds.get_pair(idx % len(self.ds.data_list))
            if not (perm_mat.size <= 2 * 2 or perm_mat.size >= cfg.PROBLEM.MAX_PROB_SIZE > 0):
                break
            idx = random.randint(0, len(self) - 1)

        ret_dict = {'Fi': Fi,
                    'Fj': Fj,
                    'gt_perm_mat': perm_mat,
                    'ns': [torch.tensor(x) for x in perm_mat.shape],
                    'solution': torch.tensor(sol),
                    'name': name,
                    'univ_size': [torch.tensor(x) for x in perm_mat.shape],
                    }
        return ret_dict
def collate_fn(data: list):
    """
    Create mini-batch data for training.

    Values are merged recursively: lists position by position, dicts key by
    key, tensors and numpy arrays by zero-padding to a common shape and
    stacking along a new first dimension, PyG ``Data`` objects into a
    ``Batch``; strings and tuples are kept as the tuple of per-sample values.

    :param data: list of per-sample dicts (all with the same keys in the same
        order, including ``'univ_size'``)
    :return: mini-batch dict; in addition to the merged keys it contains
        ``'batch_size'``, a per-sample ``'univ_size'`` maximum, ``'KGHs'`` when
        ``'Gs'``/``'Hs'`` are present, ``'aff_mat'`` when ``'Fi'``/``'Fj'`` are
        present, and ``'num_graphs'`` when some value is a list
    :raises ValueError: on an unsupported value type, on tensors of different
        rank, or when ``'Gs'``/``'Hs'`` do not fit the configured problem type
    """

    def pad_tensor(inp):
        """Zero-pad every tensor at the end of each dimension to the common maximum shape."""
        assert isinstance(inp[0], torch.Tensor)
        ndim = inp[0].dim()
        if any(t.dim() != ndim for t in inp):
            raise ValueError('Cannot pad tensors with different numbers of dimensions.')
        max_shape = [max(sizes) for sizes in zip(*(t.shape for t in inp))]

        padded_ts = []
        for t in inp:
            # F.pad expects (left, right) pairs starting from the LAST dimension
            pad_pattern = []
            for cur, tgt in zip(reversed(t.shape), reversed(max_shape)):
                pad_pattern += [0, tgt - cur]
            padded_ts.append(F.pad(t, tuple(pad_pattern), 'constant', 0))
        return padded_ts

    def stack(inp):
        """Recursively merge the per-sample values in ``inp`` into one batched value."""
        first = inp[0]
        if isinstance(first, list):
            ret = [stack(vs) for vs in zip(*inp)]
        elif isinstance(first, dict):
            ret = {}
            for kvs in zip(*[x.items() for x in inp]):
                ks, vs = zip(*kvs)
                key = ks[0]
                assert all(k == key for k in ks), "Keys mismatch."
                ret[key] = stack(vs)
        elif isinstance(first, torch.Tensor):
            ret = torch.stack(pad_tensor(inp), 0)
        elif isinstance(first, np.ndarray):
            ret = torch.stack(pad_tensor([torch.from_numpy(x) for x in inp]), 0)
        elif isinstance(first, (str, tuple)):
            ret = inp
        elif isinstance(first, pyg.data.Data):
            ret = pyg.data.Batch.from_data_list(inp)
        else:
            raise ValueError('Cannot handle type {}'.format(type(first)))
        return ret

    ret = stack(data)

    # compute CPU-intensive Kronecker product here to leverage multi-processing nature of dataloader
    if 'Gs' in ret and 'Hs' in ret:
        sparse_dtype = np.float16 if cfg.FP16 else np.float32
        if cfg.PROBLEM.TYPE == '2GM' and len(ret['Gs']) == 2 and len(ret['Hs']) == 2:
            G1, G2 = ret['Gs']
            H1, H2 = ret['Hs']
            # 1 as source graph, 2 as target graph
            K1G = [kronecker_sparse(x, y).astype(sparse_dtype) for x, y in zip(G2, G1)]
            K1H = [kronecker_sparse(x, y).astype(sparse_dtype) for x, y in zip(H2, H1)]
            K1G = CSRMatrix3d(K1G)
            K1H = CSRMatrix3d(K1H).transpose()
            ret['KGHs'] = K1G, K1H
        elif cfg.PROBLEM.TYPE in ['MGM', 'MGM3'] and 'Gs_tgt' in ret and 'Hs_tgt' in ret:
            ret['KGHs'] = dict()
            for idx_1, idx_2 in product(range(len(ret['Gs'])), repeat=2):
                # 1 as source graph, 2 as target graph
                G1 = ret['Gs'][idx_1]
                H1 = ret['Hs'][idx_1]
                G2 = ret['Gs_tgt'][idx_2]
                H2 = ret['Hs_tgt'][idx_2]
                KG = [kronecker_sparse(x, y).astype(sparse_dtype) for x, y in zip(G2, G1)]
                KH = [kronecker_sparse(x, y).astype(sparse_dtype) for x, y in zip(H2, H1)]
                KG = CSRMatrix3d(KG)
                KH = CSRMatrix3d(KH).transpose()
                ret['KGHs']['{},{}'.format(idx_1, idx_2)] = KG, KH
        else:
            raise ValueError('Data type not understood.')

    if 'Fi' in ret and 'Fj' in ret:
        Fi = ret['Fi']
        Fj = ret['Fj']
        aff_mat = kronecker_torch(Fj, Fi)
        ret['aff_mat'] = aff_mat

    ret['batch_size'] = len(data)

    # Per-sample maximum over all graphs. A generator is passed to max() so the
    # single-graph case works too (max(*[t]) would try to iterate a 0-d tensor).
    ret['univ_size'] = torch.tensor(
        [max(item[b] for item in ret['univ_size']) for b in range(ret['batch_size'])]
    )

    # the first list-valued entry determines the number of graphs per sample
    first_list = next((v for v in ret.values() if isinstance(v, list)), None)
    if first_list is not None:
        ret['num_graphs'] = len(first_list)

    return ret
def worker_init_fix(worker_id):
    """
    Seed the ``random`` and ``numpy.random`` generators of a dataloader worker
    with ``cfg.RANDOM_SEED + worker_id``.

    :param worker_id: id of the worker, added to the configured base seed
    """
    seed = cfg.RANDOM_SEED + worker_id
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
def worker_init_rand(worker_id):
    """
    Seed the ``random`` and ``numpy.random`` generators of a dataloader worker
    from ``torch.initial_seed()``.

    The numpy seed is reduced modulo ``2 ** 32``.

    :param worker_id: id of the worker (not used)
    """
    base_seed = torch.initial_seed()
    random.seed(base_seed)
    numpy_seed = base_seed % 2 ** 32
    np.random.seed(numpy_seed)
def get_dataloader(dataset, fix_seed=True, shuffle=False):
    """
    Wrap ``dataset`` in a ``DataLoader`` configured from ``cfg``.

    :param dataset: dataset to load from
    :param fix_seed: use ``worker_init_fix`` as the worker init function when
        true, ``worker_init_rand`` otherwise
    :param shuffle: forwarded to the ``DataLoader``
    :return: the ``DataLoader``; it batches with ``collate_fn``, uses
        ``cfg.BATCH_SIZE`` and ``cfg.DATALOADER_NUM`` workers, and does not pin
        memory
    """
    if fix_seed:
        init_fn = worker_init_fix
    else:
        init_fn = worker_init_rand
    loader_kwargs = dict(
        batch_size=cfg.BATCH_SIZE,
        shuffle=shuffle,
        num_workers=cfg.DATALOADER_NUM,
        collate_fn=collate_fn,
        pin_memory=False,
        worker_init_fn=init_fn,
    )
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)