# CrossEntropyLoss

class src.loss_func.CrossEntropyLoss

Multi-class cross-entropy loss between a predicted doubly-stochastic matrix and a ground truth permutation matrix.

$$L_{ce} = -\sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2} \mathbf{X}^{gt}_{i,j} \log \mathbf{S}_{i,j}$$

where $$\mathcal{V}_1, \mathcal{V}_2$$ are the vertex sets of the two graphs.
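As a worked instance of the formula above, the following framework-free sketch evaluates the loss for a single pair of graphs. The matrices `S` and `X_gt` and the helper `cross_entropy` are hypothetical illustrations, not part of this class; the values are chosen only so the result is easy to check by hand.

```python
import math

# Hypothetical example: graph 1 has 2 nodes, graph 2 has 3 nodes.
S = [[0.7, 0.2, 0.1],
     [0.1, 0.8, 0.1]]    # predicted soft assignment (S)
X_gt = [[1, 0, 0],
        [0, 1, 0]]       # ground truth matching (X^gt)


def cross_entropy(S, X_gt):
    """Evaluate -sum_{i,j} X_gt[i][j] * log(S[i][j]) for one graph pair."""
    total = 0.0
    for s_row, x_row in zip(S, X_gt):
        for s, x in zip(s_row, x_row):
            if x:  # entries where X_gt is 0 contribute nothing to the sum
                total -= x * math.log(s)
    return total


loss = cross_entropy(S, X_gt)
```

Only the two matched entries contribute, so here the loss equals $$-(\log 0.7 + \log 0.8)$$.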

Note

For batched input, this loss function returns the loss averaged over all instances in the batch.

forward(pred_dsmat: torch.Tensor, gt_perm: torch.Tensor, src_ns: torch.Tensor, tgt_ns: torch.Tensor) → torch.Tensor
Parameters
• pred_dsmat: $$(b\times n_1 \times n_2)$$ predicted doubly-stochastic matrix $$(\mathbf{S})$$

• gt_perm: $$(b\times n_1 \times n_2)$$ ground truth permutation matrix $$(\mathbf{X}^{gt})$$

• src_ns: $$(b)$$ exact number of nodes in the first graph (also known as the source graph) for each instance.

• tgt_ns: $$(b)$$ exact number of nodes in the second graph (also known as the target graph) for each instance.

Returns

$$(1)$$ averaged cross-entropy loss

Note

We support batched instances with different numbers of nodes, so src_ns and tgt_ns are required to specify the exact number of nodes of each instance in the batch.
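To illustrate why src_ns and tgt_ns are needed, here is a minimal sketch of a batched computation over padded inputs. The function `batched_cross_entropy` is a hypothetical stand-in, not this class's source code, and it assumes the average is taken over the number of instances; the library's actual normalization may differ.

```python
import torch


def batched_cross_entropy(pred_dsmat, gt_perm, src_ns, tgt_ns):
    """Sketch: per-instance cross entropy on the valid block, averaged over the batch."""
    batch_size = pred_dsmat.shape[0]
    total = pred_dsmat.new_zeros(())
    for b in range(batch_size):
        n1, n2 = int(src_ns[b]), int(tgt_ns[b])
        # Slice away the padding so only real nodes enter the loss.
        S = pred_dsmat[b, :n1, :n2]
        X = gt_perm[b, :n1, :n2]
        # Only entries with X == 1 contribute; indexing avoids 0 * log(0) = nan.
        total = total - torch.log(S[X > 0]).sum()
    # Assumption: divide by the number of instances in the batch.
    return total / batch_size


# Two instances padded to 3 x 3: the first is 3 x 3, the second is 2 x 2.
pred_dsmat = torch.tensor([
    [[0.7, 0.2, 0.1], [0.2, 0.7, 0.1], [0.1, 0.1, 0.8]],
    [[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.0, 0.0, 0.0]],
])
gt_perm = torch.tensor([
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
])
src_ns = torch.tensor([3, 2])
tgt_ns = torch.tensor([3, 2])

loss = batched_cross_entropy(pred_dsmat, gt_perm, src_ns, tgt_ns)
```

Without the slicing step, the zero-padded rows and columns of the second instance would be treated as real nodes. Note that this sketch returns a zero-dimensional tensor rather than one of shape $$(1)$$.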

training: bool