# Source code for src.lap_solvers.hungarian

import torch
import scipy.optimize as opt
import numpy as np
from multiprocessing import Pool
from torch import Tensor

def hungarian(s: Tensor, n1: Tensor=None, n2: Tensor=None, nproc: int=1) -> Tensor:
    r"""
    Solve optimal LAP permutation by hungarian algorithm. The time cost is :math:`O(n^3)`.

    :param s: :math:`(b\times n_1 \times n_2)` input 3d tensor. :math:`b`: batch size
    :param n1: :math:`(b)` number of objects in dim1
    :param n2: :math:`(b)` number of objects in dim2
    :param nproc: number of parallel processes (default: nproc=1 for no parallel)
    :return: :math:`(b\times n_1 \times n_2)` optimal permutation matrix

    .. note::
        We support batched instances with different number of nodes, therefore ``n1`` and ``n2`` are
        required to specify the exact number of objects of each dimension in the batch. If not specified, we assume
        the batched matrices are not padded.
    """
    # Accept a single (n1 x n2) matrix by promoting it to a batch of one;
    # remember to squeeze the batch dim back out before returning.
    if len(s.shape) == 2:
        s = s.unsqueeze(0)
        matrix_input = True
    elif len(s.shape) == 3:
        matrix_input = False
    else:
        raise ValueError('input data shape not understood: {}'.format(s.shape))

    device = s.device
    batch_num = s.shape[0]  # fix: was `s.shape` (a tuple), but an int batch count is needed below

    # Negate the scores: scipy's linear_sum_assignment minimizes cost,
    # while we want the maximum-score assignment.
    perm_mat = s.cpu().detach().numpy() * -1
    if n1 is not None:
        n1 = n1.cpu().numpy()
    else:
        n1 = [None] * batch_num
    if n2 is not None:
        n2 = n2.cpu().numpy()
    else:
        n2 = [None] * batch_num

    if nproc > 1:
        # Solve each batch element's LAP in a separate worker process.
        with Pool(processes=nproc) as pool:
            mapresult = pool.starmap_async(_hung_kernel, zip(perm_mat, n1, n2))
            perm_mat = np.stack(mapresult.get())
    else:
        perm_mat = np.stack([_hung_kernel(perm_mat[b], n1[b], n2[b]) for b in range(batch_num)])

    perm_mat = torch.from_numpy(perm_mat).to(device)

    if matrix_input:
        perm_mat.squeeze_(0)

    return perm_mat

def _hung_kernel(s: torch.Tensor, n1=None, n2=None):
if n1 is None:
n1 = s.shape
if n2 is None:
n2 = s.shape
row, col = opt.linear_sum_assignment(s[:n1, :n2])
perm_mat = np.zeros_like(s)
perm_mat[row, col] = 1
return perm_mat