# PermutationLoss¶

class src.loss_func.PermutationLoss[source]

Binary cross entropy loss between the predicted doubly-stochastic matrix and the ground truth permutation matrix, also known as “permutation loss”. Proposed by “Wang et al. Learning Combinatorial Embedding Networks for Deep Graph Matching. ICCV 2019.”

$L_{perm} =- \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2} \left(\mathbf{X}^{gt}_{i,j} \log \mathbf{S}_{i,j} + (1-\mathbf{X}^{gt}_{i,j}) \log (1-\mathbf{S}_{i,j}) \right)$

where $$\mathcal{V}_1, \mathcal{V}_2$$ are vertex sets for two graphs.

Note

For batched input, this loss function computes the averaged loss among all instances in the batch.
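Below is a minimal PyTorch sketch of the formula above for fixed-size inputs. It is for illustration only, not the library's implementation (which additionally handles per-instance node counts through `forward`, documented below); the function name is hypothetical.

```python
import torch
import torch.nn.functional as F

def permutation_loss_sketch(pred_dsmat: torch.Tensor, gt_perm: torch.Tensor) -> torch.Tensor:
    """Sum of -[X^gt log S + (1 - X^gt) log(1 - S)] over all (i, j),
    averaged among the instances in the batch."""
    pred_dsmat = pred_dsmat.to(torch.float32).clamp(0.0, 1.0)  # S must lie in [0, 1]
    gt_perm = gt_perm.to(torch.float32)
    # element-wise binary cross entropy, summed over both node dimensions
    bce = F.binary_cross_entropy(pred_dsmat, gt_perm, reduction='none')
    return bce.sum(dim=(1, 2)).mean()  # average over the batch dimension
```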

forward(pred_dsmat: torch.Tensor, gt_perm: torch.Tensor, src_ns: torch.Tensor, tgt_ns: torch.Tensor) → torch.Tensor[source]
Parameters
• pred_dsmat – $$(b\times n_1 \times n_2)$$ predicted doubly-stochastic matrix $$(\mathbf{S})$$

• gt_perm – $$(b\times n_1 \times n_2)$$ ground truth permutation matrix $$(\mathbf{X}^{gt})$$

• src_ns – $$(b)$$ number of nodes in each instance of the first graph (also known as the source graph).

• tgt_ns – $$(b)$$ number of nodes in each instance of the second graph (also known as the target graph).

Returns

$$(1)$$ averaged permutation loss

Note

We support batched instances with different numbers of nodes, therefore src_ns and tgt_ns are required to specify the exact number of nodes of each instance in the batch.
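As a usage sketch (tensor shapes and values are illustrative only, and the random pred_dsmat stands in for a real doubly-stochastic prediction), a batch of two instances with 3 and 4 nodes can be passed as padded $$4\times 4$$ matrices while src_ns and tgt_ns carry the true node counts:

```python
import torch
from src.loss_func import PermutationLoss

criterion = PermutationLoss()

b, n1, n2 = 2, 4, 4
pred_dsmat = torch.rand(b, n1, n2)         # stand-in for the predicted matrix S
gt_perm = torch.zeros(b, n1, n2)
gt_perm[0, :3, :3] = torch.eye(3)          # instance 0 matches 3 node pairs
gt_perm[1, :4, :4] = torch.eye(4)          # instance 1 matches 4 node pairs

src_ns = torch.tensor([3, 4])              # nodes per source graph instance
tgt_ns = torch.tensor([3, 4])              # nodes per target graph instance

loss = criterion(pred_dsmat, gt_perm, src_ns, tgt_ns)
print(loss)                                # scalar averaged permutation loss
```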
