CrossEntropyLoss

class src.loss_func.CrossEntropyLoss[source]

Multi-class cross entropy loss between the predicted doubly-stochastic matrix and the ground truth permutation matrix.

\[L_{ce} =- \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2} \left(\mathbf{X}^{gt}_{i,j} \log \mathbf{S}_{i,j}\right)\]

where \(\mathcal{V}_1, \mathcal{V}_2\) are vertex sets for two graphs.

Note

For batched input, this loss function computes the averaged loss among all instances in the batch.
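To make the formula concrete, below is a minimal sketch of this computation written directly from the equation above. It is illustrative only and not the library's actual implementation; in particular, the normalization constant used for batch averaging is an assumption.

```python
import torch

def cross_entropy_loss_sketch(pred_dsmat: torch.Tensor,
                              gt_perm: torch.Tensor,
                              src_ns: torch.Tensor,
                              tgt_ns: torch.Tensor,
                              eps: float = 1e-8) -> torch.Tensor:
    """Illustrative re-implementation of L_ce; not the library source."""
    batch_size = pred_dsmat.shape[0]
    total_loss = pred_dsmat.new_zeros(())
    for b in range(batch_size):
        n1, n2 = int(src_ns[b]), int(tgt_ns[b])
        # restrict to the valid n1 x n2 block of this instance
        s = pred_dsmat[b, :n1, :n2]
        x = gt_perm[b, :n1, :n2]
        # L_ce = -sum_{i in V1, j in V2} X^gt_{i,j} * log S_{i,j}
        total_loss = total_loss - (x * torch.log(s + eps)).sum()
    # average over the instances in the batch (see the Note above);
    # the exact normalization in the library may differ.
    return total_loss / batch_size
```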

forward(pred_dsmat: torch.Tensor, gt_perm: torch.Tensor, src_ns: torch.Tensor, tgt_ns: torch.Tensor) → torch.Tensor[source]
Parameters
  • pred_dsmat – \((b\times n_1 \times n_2)\) predicted doubly-stochastic matrix \((\mathbf{S})\)

  • gt_perm – \((b\times n_1 \times n_2)\) ground truth permutation matrix \((\mathbf{X}^{gt})\)

  • src_ns – \((b)\) number of nodes in each instance of the first graph (also known as the source graph).

  • tgt_ns – \((b)\) number of nodes in each instance of the second graph (also known as the target graph).

Returns

\((1)\) averaged cross-entropy loss

Note

We support batched instances with different numbers of nodes; therefore src_ns and tgt_ns are required to specify the exact number of nodes of each instance in the batch.
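A usage sketch follows, assuming the class is importable from src.loss_func as documented above. The row-wise softmax is only a stand-in for a properly doubly-stochastic (e.g. Sinkhorn-normalized) prediction, and the node counts are chosen arbitrarily to show a padded batch.

```python
import torch
from src.loss_func import CrossEntropyLoss

criterion = CrossEntropyLoss()

b, n1, n2 = 2, 5, 5
logits = torch.randn(b, n1, n2, requires_grad=True)
pred_dsmat = torch.softmax(logits, dim=-1)             # stand-in for a doubly-stochastic matrix
gt_perm = torch.eye(n1).unsqueeze(0).repeat(b, 1, 1)   # identity permutation as ground truth

src_ns = torch.tensor([5, 4])   # per-instance node counts; the second instance is padded
tgt_ns = torch.tensor([5, 4])

loss = criterion(pred_dsmat, gt_perm, src_ns, tgt_ns)
loss.backward()                 # gradients flow back to the logits
```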

training: bool