# Source code for src.lap_solvers.hungarian

import torch
import scipy.optimize as opt
import numpy as np
from multiprocessing import Pool
from torch import Tensor

def hungarian(s: Tensor, n1: Tensor = None, n2: Tensor = None, nproc: int = 1) -> Tensor:
    r"""
    Solve optimal LAP permutation by hungarian algorithm. The time cost is :math:`O(n^3)`.

    The scores in ``s`` are *maximized*: they are negated before being handed to
    :func:`scipy.optimize.linear_sum_assignment`, which minimizes cost.

    :param s: :math:`(b\times n_1 \times n_2)` input 3d tensor. :math:`b`: batch size.
        A 2d :math:`(n_1 \times n_2)` tensor is also accepted and is solved as a batch of one.
    :param n1: :math:`(b)` number of objects in dim1 (a tensor, or any array-like of ints)
    :param n2: :math:`(b)` number of objects in dim2 (a tensor, or any array-like of ints)
    :param nproc: number of parallel processes (default: ``nproc=1`` for no parallel)
    :return: :math:`(b\times n_1 \times n_2)` optimal permutation matrix on the same device as ``s``
        (2d if ``s`` was 2d). Entries outside the valid ``n1 x n2`` sub-matrix of each instance are 0.
    :raises ValueError: if ``s`` is neither 2d nor 3d

    .. note::
        We support batched instances with different number of nodes, therefore ``n1`` and ``n2`` are
        required to specify the exact number of objects of each dimension in the batch. If not specified,
        we assume the batched matrices are not padded.
    """
    if len(s.shape) == 2:
        s = s.unsqueeze(0)
        matrix_input = True
    elif len(s.shape) == 3:
        matrix_input = False
    else:
        raise ValueError('input data shape not understood: {}'.format(s.shape))

    device = s.device
    batch_num = s.shape[0]

    # Negate so that scipy's cost *minimization* yields the score-maximizing assignment.
    # detach() before cpu() so the host copy does not carry the autograd graph along.
    neg_s = s.detach().cpu().numpy() * -1

    def _as_sizes(n):
        # None means "use the full dimension" for every instance in the batch.
        if n is None:
            return [None] * batch_num
        if isinstance(n, Tensor):
            return n.cpu().numpy()
        return np.asarray(n)

    n1 = _as_sizes(n1)
    n2 = _as_sizes(n2)

    # Never start more workers than there are instances; one worker is pure overhead.
    workers = min(nproc, batch_num)
    if batch_num == 0:
        # np.stack cannot stack an empty list; an empty batch has an empty (all-zero) solution.
        perm_mat = np.zeros_like(neg_s)
    elif workers > 1:
        with Pool(processes=workers) as pool:
            perm_mat = np.stack(pool.starmap(_hung_kernel, zip(neg_s, n1, n2)))
    else:
        perm_mat = np.stack([_hung_kernel(neg_s[b], n1[b], n2[b]) for b in range(batch_num)])

    perm_mat = torch.from_numpy(perm_mat).to(device)
    if matrix_input:
        perm_mat.squeeze_(0)
    return perm_mat
def _hung_kernel(s: torch.Tensor, n1=None, n2=None): if n1 is None: n1 = s.shape[0] if n2 is None: n2 = s.shape[1] row, col = opt.linear_sum_assignment(s[:n1, :n2]) perm_mat = np.zeros_like(s) perm_mat[row, col] = 1 return perm_mat