GumbelSinkhorn
- class src.lap_solvers.sinkhorn.GumbelSinkhorn(max_iter=10, tau=1.0, epsilon=0.0001, batched_operation=False)[source]
Gumbel-Sinkhorn layer that turns the input matrix into a doubly-stochastic matrix by adding Gumbel noise before Sinkhorn normalization. See details in “Mena et al. Learning Latent Permutations with Gumbel-Sinkhorn Networks. ICLR 2018”.
- Parameters
max_iter – maximum number of Sinkhorn iterations (default: 10)
tau – the hyperparameter \(\tau\) controlling the temperature (default: 1.0)
epsilon – a small number for numerical stability (default: 1e-4)
batched_operation – apply batched operation for better efficiency (but may cause issues for back-propagation, default: False)
Note
This module only supports log-scale Sinkhorn operation.
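A minimal usage sketch (assuming the module is importable under the src.lap_solvers.sinkhorn path shown above; the tensor shapes and values below are illustrative only):

```python
import torch
from src.lap_solvers.sinkhorn import GumbelSinkhorn

# Build the layer with the defaults documented above.
gs = GumbelSinkhorn(max_iter=10, tau=1.0, epsilon=1e-4)

# A batch of 4 random 5x5 score matrices (illustrative values only).
s = torch.rand(4, 5, 5)

# Each matrix is perturbed with Gumbel noise and Sinkhorn-normalized;
# the sampled doubly-stochastic matrices are stacked along dim 0.
out = gs(s, sample_num=5)
print(out.shape)  # expected: torch.Size([20, 5, 5]), i.e. (b * sample_num, n1, n2)
```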
- forward(s: torch.Tensor, nrows: Optional[torch.Tensor] = None, ncols: Optional[torch.Tensor] = None, sample_num=5, dummy_row=False) → torch.Tensor[source]
- Parameters
s – \((b\times n_1 \times n_2)\) input 3d tensor. \(b\): batch size
nrows – \((b)\) number of objects in dim1
ncols – \((b)\) number of objects in dim2
sample_num – number of samples (default: 5)
dummy_row – whether to add dummy rows (rows whose elements are all 0) to pad the matrix to a square matrix (default: False)
- Returns
\((b\cdot m \times n_1 \times n_2)\) the computed doubly-stochastic matrix. \(m\): number of samples (sample_num)
The samples are stacked along the first dimension of the output tensor. You may reshape the output tensor s as:
s = torch.reshape(s, (-1, sample_num, s.shape[1], s.shape[2]))
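For example, continuing the sketch above (where out has shape \((b\cdot m \times n_1 \times n_2)\)), the stacked samples can be regrouped per input instance:

```python
sample_num = 5  # must match the value passed to forward()

# Regroup: (b * sample_num, n1, n2) -> (b, sample_num, n1, n2).
out = torch.reshape(out, (-1, sample_num, out.shape[1], out.shape[2]))

# e.g. average over the samples to obtain one relaxed matching per instance.
avg = out.mean(dim=1)  # (b, n1, n2)
```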
Note
We support batched instances with different numbers of nodes, therefore nrows and ncols are required to specify the exact number of objects of each dimension in the batch. If not specified, we assume the batched matrices are not padded.
Note
The original Sinkhorn algorithm only works for square matrices. To handle cases where the graphs to be matched have different numbers of nodes, it is common practice to add dummy rows to construct a square matrix. After the row and column normalizations, the padded rows are discarded.
Note
We assume row number <= column number. If not, the input matrix will be transposed.
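The sketch below combines the notes above: score matrices of different sizes are padded into a single tensor with row number <= column number, nrows/ncols mark the valid sizes per instance, and dummy_row pads each matrix to square. Shapes and values are assumptions for illustration only.

```python
import torch
from src.lap_solvers.sinkhorn import GumbelSinkhorn

gs = GumbelSinkhorn(max_iter=10, tau=1.0)

# Two instances padded into one (2 x 3 x 4) tensor; instance 0 is really 2x3,
# instance 1 is really 3x4 (rows <= columns in both cases).
s = torch.rand(2, 3, 4)
nrows = torch.tensor([2, 3])  # valid rows per instance
ncols = torch.tensor([3, 4])  # valid columns per instance

# dummy_row=True pads with all-zero rows so each matrix is normalized as a
# square matrix; the padded rows are discarded afterwards.
out = gs(s, nrows=nrows, ncols=ncols, sample_num=2, dummy_row=True)
print(out.shape)  # expected: torch.Size([4, 3, 4]), i.e. (b * sample_num, n1, n2)
```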
- training: bool