from multiprocessing import Pool
from typing import Optional

import numpy as np
import scipy.optimize as opt
import torch
from torch import Tensor
def hungarian(s: Tensor, n1: Optional[Tensor] = None, n2: Optional[Tensor] = None, nproc: int = 1) -> Tensor:
    r"""
    Solve optimal LAP permutation by hungarian algorithm. The time cost is :math:`O(n^3)`.

    :param s: :math:`(b\times n_1 \times n_2)` input 3d tensor. :math:`b`: batch size
    :param n1: :math:`(b)` number of objects in dim1
    :param n2: :math:`(b)` number of objects in dim2
    :param nproc: number of parallel processes (default: ``nproc=1`` for no parallel)
    :return: :math:`(b\times n_1 \times n_2)` optimal permutation matrix
    :raises ValueError: if ``s`` is neither 2d nor 3d

    .. note::
        We support batched instances with different number of nodes, therefore ``n1`` and ``n2`` are
        required to specify the exact number of objects of each dimension in the batch. If not specified, we assume
        the batched matrices are not padded.
    """
    # Accept a single (n1 x n2) matrix as well as a batched (b x n1 x n2) tensor;
    # remember which form we got so the output matches it.
    if len(s.shape) == 2:
        s = s.unsqueeze(0)
        matrix_input = True
    elif len(s.shape) == 3:
        matrix_input = False
    else:
        raise ValueError('input data shape not understood: {}'.format(s.shape))

    device = s.device
    batch_num = s.shape[0]

    # scipy's linear_sum_assignment minimizes cost, while we want the maximum-score
    # permutation, so negate the scores before handing them to the solver.
    perm_mat = s.cpu().detach().numpy() * -1

    # Per-instance valid sizes; None means "use the full matrix" in _hung_kernel.
    if n1 is not None:
        n1 = n1.cpu().numpy()
    else:
        n1 = [None] * batch_num
    if n2 is not None:
        n2 = n2.cpu().numpy()
    else:
        n2 = [None] * batch_num

    if nproc > 1:
        # Solve the LAPs of the batch in parallel worker processes.
        with Pool(processes=nproc) as pool:
            mapresult = pool.starmap_async(_hung_kernel, zip(perm_mat, n1, n2))
            perm_mat = np.stack(mapresult.get())
    else:
        perm_mat = np.stack([_hung_kernel(perm_mat[b], n1[b], n2[b]) for b in range(batch_num)])

    perm_mat = torch.from_numpy(perm_mat).to(device)
    if matrix_input:
        perm_mat.squeeze_(0)
    return perm_mat
def _hung_kernel(s: torch.Tensor, n1=None, n2=None):
if n1 is None:
n1 = s.shape[0]
if n2 is None:
n2 = s.shape[1]
row, col = opt.linear_sum_assignment(s[:n1, :n2])
perm_mat = np.zeros_like(s)
perm_mat[row, col] = 1
return perm_mat