Source code for src.lap_solvers.sinkhorn

import torch
import torch.nn as nn
from torch import Tensor
import pygmtools as pygm


class Sinkhorn(nn.Module):
    r"""
    Sinkhorn algorithm turns the input matrix into a bi-stochastic matrix.

    The Sinkhorn algorithm first applies an ``exp`` function with temperature :math:`\tau`:

    .. math::
        \mathbf{S}_{i,j} = \exp \left(\frac{\mathbf{s}_{i,j}}{\tau}\right)

    and then turns the matrix into a doubly-stochastic matrix by iterative row- and column-wise normalization:

    .. math::
        \mathbf{S} &= \mathbf{S} \oslash (\mathbf{1}_{n_2} \mathbf{1}_{n_2}^\top \cdot \mathbf{S}) \\
        \mathbf{S} &= \mathbf{S} \oslash (\mathbf{S} \cdot \mathbf{1}_{n_2} \mathbf{1}_{n_2}^\top)

    where :math:`\oslash` means element-wise division and :math:`\mathbf{1}_n` means a column vector of length
    :math:`n` whose elements are all :math:`1`\ s.

    :param max_iter: maximum iterations (default: ``10``)
    :param tau: the hyperparameter :math:`\tau` controlling the temperature (default: ``1``)
    :param epsilon: a small number for numerical stability (default: ``1e-4``)
    :param log_forward: apply log-scale computation for better numerical stability (default: ``True``)
    :param batched_operation: apply batched operation for better efficiency (but may cause issues for
        back-propagation, default: ``False``)

    .. note::
        ``tau`` is an important hyperparameter of the Sinkhorn algorithm. ``tau`` controls the distance between
        the predicted doubly-stochastic matrix and the discrete permutation matrix computed by the Hungarian
        algorithm (see :func:`~src.lap_solvers.hungarian.hungarian`). Given a small ``tau``, Sinkhorn performs
        more closely to Hungarian, at the cost of slower convergence and reduced numerical stability
        (see the usage sketch following this class).

    .. note::
        We recommend setting ``log_forward=True`` because it is more numerically stable. It provides more precise
        gradients in back-propagation and helps the model converge better and faster.

    .. note::
        Setting ``batched_operation=True`` may be preferred when you are doing inference with this module and do
        not need the gradient.
    """
    def __init__(self, max_iter: int=10, tau: float=1., epsilon: float=1e-4,
                 log_forward: bool=True, batched_operation: bool=False):
        super(Sinkhorn, self).__init__()
        self.max_iter = max_iter
        self.tau = tau
        self.epsilon = epsilon
        self.log_forward = log_forward
        if not log_forward:
            print('Warning: Sinkhorn algorithm without log forward is deprecated because log_forward is more stable.')
        # batched operation may cause instability in backward computation, but will boost computation speed
        self.batched_operation = batched_operation
    def forward(self, s: Tensor, nrows: Tensor=None, ncols: Tensor=None, dummy_row: bool=False) -> Tensor:
        r"""
        :param s: :math:`(b\times n_1 \times n_2)` input 3d tensor. :math:`b`: batch size
        :param nrows: :math:`(b)` number of objects in dim1
        :param ncols: :math:`(b)` number of objects in dim2
        :param dummy_row: whether to add dummy rows (rows whose elements are all 0) to pad the matrix into a square
            matrix (default: ``False``)
        :return: :math:`(b\times n_1 \times n_2)` the computed doubly-stochastic matrix

        .. note::
            We support batched instances with different numbers of nodes, therefore ``nrows`` and ``ncols`` are
            required to specify the exact number of objects in each dimension of the batch. If not specified, we
            assume the batched matrices are not padded.

        .. note::
            The original Sinkhorn algorithm only works for square matrices. To handle cases where the graphs to be
            matched have different numbers of nodes, it is a common practice to add dummy rows to construct a square
            matrix. After the row and column normalizations, the padded rows are discarded.

        .. note::
            We assume row number <= column number. If not, the input matrix will be transposed.
        """
        if self.log_forward:
            return self.forward_log(s, nrows, ncols, dummy_row)
        else:
            return self.forward_ori(s, nrows, ncols, dummy_row)  # deprecated
    def forward_log(self, s, nrows=None, ncols=None, dummy_row=False):
        """Compute Sinkhorn with row/column normalization in the log space."""
        return pygm.sinkhorn(s, n1=nrows, n2=ncols, dummy_row=dummy_row, max_iter=self.max_iter, tau=self.tau,
                             batched_operation=self.batched_operation, backend='pytorch')
    def forward_ori(self, s, nrows=None, ncols=None, dummy_row=False):
        r"""
        Compute Sinkhorn with row/column normalization.

        .. warning::
            This function is deprecated because :meth:`~src.lap_solvers.sinkhorn.Sinkhorn.forward_log` is more
            numerically stable.
        """
        if len(s.shape) == 2:
            s = s.unsqueeze(0)
            matrix_input = True
        elif len(s.shape) == 3:
            matrix_input = False
        else:
            raise ValueError('input data shape not understood.')

        batch_size = s.shape[0]

        if nrows is None:
            nrows = [s.shape[1] for _ in range(batch_size)]
        if ncols is None:
            ncols = [s.shape[2] for _ in range(batch_size)]

        # tau scaling: softmax over each valid row implements exp(s / tau) followed by a row normalization
        ret_s = torch.zeros_like(s)
        for b, n in enumerate(nrows):
            ret_s[b, 0:n, 0:ncols[b]] = \
                nn.functional.softmax(s[b, 0:n, 0:ncols[b]] / self.tau, dim=-1)
        s = ret_s

        # add dummy elements to pad the matrix into a square matrix
        if dummy_row:
            dummy_shape = list(s.shape)
            dummy_shape[1] = s.shape[2] - s.shape[1]
            s = torch.cat((s, torch.full(dummy_shape, 0.).to(s.device)), dim=1)  # non in-place
            ori_nrows = nrows
            nrows = ncols
            for b in range(batch_size):
                s[b, ori_nrows[b]:nrows[b], :ncols[b]] = self.epsilon

        # masks selecting the valid (non-padded) rows and columns of each instance
        row_norm_ones = torch.zeros(batch_size, s.shape[1], s.shape[1], device=s.device, dtype=s.dtype)  # size: row x row
        col_norm_ones = torch.zeros(batch_size, s.shape[2], s.shape[2], device=s.device, dtype=s.dtype)  # size: col x col
        for b in range(batch_size):
            row_slice = slice(0, nrows[b])
            col_slice = slice(0, ncols[b])
            row_norm_ones[b, row_slice, row_slice] = 1
            col_norm_ones[b, col_slice, col_slice] = 1

        s += self.epsilon

        # alternate column-wise and row-wise normalization
        for i in range(self.max_iter):
            if i % 2 == 0:
                # column norm
                norm = torch.sum(torch.mul(s.unsqueeze(3), col_norm_ones.unsqueeze(1)), dim=2)
            else:
                # row norm
                norm = torch.sum(torch.mul(row_norm_ones.unsqueeze(3), s.unsqueeze(1)), dim=2)

            tmp = torch.zeros_like(s)
            for b in range(batch_size):
                row_slice = slice(0, nrows[b])
                col_slice = slice(0, ncols[b])
                tmp[b, row_slice, col_slice] = 1 / norm[b, row_slice, col_slice]
            s = s * tmp

        # discard the dummy rows added above
        if dummy_row:
            if dummy_shape[1] > 0:
                s = s[:, :-dummy_shape[1]]
            for b in range(batch_size):
                s[b, ori_nrows[b]:nrows[b], :ncols[b]] = 0

        if matrix_input:
            s.squeeze_(0)

        return s
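

# A minimal usage sketch of the ``tau`` behaviour described in the class docstring above: a
# smaller ``tau`` drives the doubly-stochastic output closer to a discrete permutation matrix.
# The helper name, scores and sizes below are illustrative only and are not used elsewhere.
def _sinkhorn_tau_sketch():
    scores = torch.tensor([[[3., 1., 0.],
                            [1., 4., 1.],
                            [0., 1., 5.]]])
    nrows = torch.tensor([3])
    ncols = torch.tensor([3])
    soft = Sinkhorn(max_iter=10, tau=1.)(scores, nrows, ncols)     # smooth assignment
    sharp = Sinkhorn(max_iter=10, tau=0.05)(scores, nrows, ncols)  # nearly a 0/1 permutation
    print(soft)
    print(sharp)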
class GumbelSinkhorn(nn.Module):
    r"""
    Gumbel Sinkhorn Layer turns the input matrix into a bi-stochastic matrix.
    See details in `"Mena et al. Learning Latent Permutations with Gumbel-Sinkhorn Networks. ICLR 2018"
    <https://arxiv.org/abs/1802.08665>`_

    :param max_iter: maximum iterations (default: ``10``)
    :param tau: the hyperparameter :math:`\tau` controlling the temperature (default: ``1``)
    :param epsilon: a small number for numerical stability (default: ``1e-4``)
    :param batched_operation: apply batched operation for better efficiency (but may cause issues for
        back-propagation, default: ``False``)

    .. note::
        This module only supports log-scale Sinkhorn operation.
    """
    def __init__(self, max_iter=10, tau=1., epsilon=1e-4, batched_operation=False):
        super(GumbelSinkhorn, self).__init__()
        self.sinkhorn = Sinkhorn(max_iter, tau, epsilon, batched_operation=batched_operation)
    def forward(self, s: Tensor, nrows: Tensor=None, ncols: Tensor=None,
                sample_num=5, dummy_row=False) -> Tensor:
        r"""
        :param s: :math:`(b\times n_1 \times n_2)` input 3d tensor. :math:`b`: batch size
        :param nrows: :math:`(b)` number of objects in dim1
        :param ncols: :math:`(b)` number of objects in dim2
        :param sample_num: number of samples
        :param dummy_row: whether to add dummy rows (rows whose elements are all 0) to pad the matrix into a square
            matrix (default: ``False``)
        :return: :math:`(b m\times n_1 \times n_2)` the computed doubly-stochastic matrix. :math:`m`: number of
            samples (``sample_num``)

            The samples are stacked at the first dimension of the output tensor. You may reshape the output
            tensor ``s`` as (see the sketch following this class):

            ::

                s = torch.reshape(s, (-1, sample_num, s.shape[1], s.shape[2]))

        .. note::
            We support batched instances with different numbers of nodes, therefore ``nrows`` and ``ncols`` are
            required to specify the exact number of objects in each dimension of the batch. If not specified, we
            assume the batched matrices are not padded.

        .. note::
            The original Sinkhorn algorithm only works for square matrices. To handle cases where the graphs to be
            matched have different numbers of nodes, it is a common practice to add dummy rows to construct a square
            matrix. After the row and column normalizations, the padded rows are discarded.

        .. note::
            We assume row number <= column number. If not, the input matrix will be transposed.
        """
        def sample_gumbel(t_like, eps=1e-20):
            """Randomly sample standard Gumbel variables."""
            u = torch.empty_like(t_like).uniform_()
            return -torch.log(-torch.log(u + eps) + eps)

        s_rep = torch.repeat_interleave(s, sample_num, dim=0)
        s_rep = s_rep + sample_gumbel(s_rep)
        nrows_rep = torch.repeat_interleave(nrows, sample_num, dim=0)
        ncols_rep = torch.repeat_interleave(ncols, sample_num, dim=0)
        s_rep = self.sinkhorn(s_rep, nrows_rep, ncols_rep, dummy_row)
        return s_rep
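

# A minimal sketch of drawing several Gumbel-Sinkhorn samples for a padded batch and
# regrouping them per batch item with the reshape suggested in the docstring above.
# The helper name, shapes and node counts are illustrative assumptions only.
def _gumbel_sinkhorn_sketch():
    gs = GumbelSinkhorn(max_iter=10, tau=0.1)
    scores = torch.rand(2, 4, 4)   # batch of 2 score matrices, padded to 4x4
    nrows = torch.tensor([3, 4])   # valid rows per instance
    ncols = torch.tensor([3, 4])   # valid columns per instance
    sample_num = 5
    out = gs(scores, nrows, ncols, sample_num=sample_num)                   # shape: (2*5) x 4 x 4
    out = torch.reshape(out, (-1, sample_num, out.shape[1], out.shape[2]))  # shape: 2 x 5 x 4 x 4
    print(out.shape)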
if __name__ == '__main__':
    bs = Sinkhorn(max_iter=8, epsilon=1e-4)
    inp = torch.tensor([[[1., 0, 1.],
                         [1., 0, 3.],
                         [2., 0, 1.],
                         [4., 0, 2.]]], requires_grad=True)
    outp = bs(inp, (3, 4))

    print(outp)
    l = torch.sum(outp)
    l.backward()
    print(inp.grad * 1e10)

    outp2 = torch.tensor([[0.1, 0.1, 1],
                          [2, 3, 4.]], requires_grad=True)

    l = torch.sum(outp2)
    l.backward()
    print(outp2.grad)
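
    # A small additional sanity check with illustrative values: on a well-conditioned square
    # input, the deprecated original forward and the log-space forward are expected to give
    # similar doubly-stochastic outputs; the printed value is their maximum absolute gap.
    sq = torch.rand(1, 3, 3)
    s_log = Sinkhorn(max_iter=20, log_forward=True)(sq)
    s_ori = Sinkhorn(max_iter=20, log_forward=False)(sq)
    print(torch.max(torch.abs(s_log - s_ori)))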