InnerProductLoss

class src.loss_func.InnerProductLoss

Inner product loss for self-supervised problems.

\[L_{ip} = - \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2} \mathbf{X}^{gt}_{i,j} \mathbf{S}_{i,j}\]

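where \(\mathcal{V}_1, \mathcal{V}_2\) are the vertex sets of the two graphs, \(\mathbf{S}\) is the predicted doubly-stochastic matrix and \(\mathbf{X}^{gt}\) is the ground truth permutation matrix.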
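As a minimal sketch for a single, unbatched instance (plain Python; `inner_product_loss` is a hypothetical helper, not part of `src.loss_func`), the formula above is the negated sum of element-wise products of the two matrices. When \(\mathbf{X}^{gt}\) is a permutation matrix, this reduces to minus the sum of the predicted scores at the ground-truth matches, so a perfect prediction \(\mathbf{S} = \mathbf{X}^{gt}\) gives the most negative value.

```python
from typing import List

Matrix = List[List[float]]


def inner_product_loss(s: Matrix, x_gt: Matrix) -> float:
    """Hypothetical helper: L = -sum_{i,j} X^gt[i][j] * S[i][j] for one instance."""
    if len(s) != len(x_gt) or any(len(r) != len(q) for r, q in zip(s, x_gt)):
        raise ValueError("S and X^gt must have the same n1 x n2 shape")
    # Multiply matching entries and sum over every (i, j) pair, then negate.
    return -sum(
        v * x
        for s_row, x_row in zip(s, x_gt)
        for v, x in zip(s_row, x_row)
    )


# 2 x 2 example: the ground truth matches node 0 -> 1 and node 1 -> 0.
S = [[0.25, 0.75],
     [0.75, 0.25]]
X_gt = [[0.0, 1.0],
        [1.0, 0.0]]
print(inner_product_loss(S, X_gt))     # -1.5
print(inner_product_loss(X_gt, X_gt))  # -2.0 (perfect prediction)
```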

Note

For batched input, this loss function returns the loss averaged over all instances in the batch.
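A minimal PyTorch sketch of a batched version, taking the same four tensors as forward (documented below): each instance's sum is restricted to its valid block using src_ns and tgt_ns, and the per-instance losses are then averaged over the b instances. The name `batched_inner_product_loss` and the choice to divide by the batch size are assumptions for illustration; the library's own forward may normalize differently.

```python
import torch


def batched_inner_product_loss(
    pred_dsmat: torch.Tensor,  # (b, n1, n2) predicted doubly-stochastic matrices
    gt_perm: torch.Tensor,     # (b, n1, n2) ground-truth permutation matrices
    src_ns: torch.Tensor,      # (b,) valid rows per instance
    tgt_ns: torch.Tensor,      # (b,) valid columns per instance
) -> torch.Tensor:
    """Hypothetical sketch: mask out padding, then average -<X^gt, S> over the batch."""
    b, n1, n2 = pred_dsmat.shape
    rows = torch.arange(n1, device=pred_dsmat.device)
    cols = torch.arange(n2, device=pred_dsmat.device)
    # mask[k, i, j] is True iff i < src_ns[k] and j < tgt_ns[k].
    mask = (rows[None, :, None] < src_ns[:, None, None]) & (
        cols[None, None, :] < tgt_ns[:, None, None]
    )
    per_instance = -(pred_dsmat * gt_perm * mask).sum(dim=(1, 2))
    # Assumed normalization: plain mean over the b instances, returned with shape (1).
    return per_instance.mean().reshape(1)


# Instance 0 is 2 x 2 padded to 3 x 3; its padding holds junk values that must be ignored.
pred = torch.tensor([
    [[0.25, 0.75, 9.0], [0.75, 0.25, 9.0], [9.0, 9.0, 9.0]],
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
])
gt = torch.tensor([
    [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
])
loss = batched_inner_product_loss(pred, gt, torch.tensor([2, 3]), torch.tensor([2, 3]))
print(loss.item())  # -2.25, i.e. the mean of -1.5 and -3.0
```

Without the mask, the junk entry at position (2, 2) of instance 0 would add 9 to that instance's inner product and distort the result.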

forward(pred_dsmat: torch.Tensor, gt_perm: torch.Tensor, src_ns: torch.Tensor, tgt_ns: torch.Tensor) → torch.Tensor
Parameters
  • pred_dsmat: \((b\times n_1 \times n_2)\) predicted doubly-stochastic matrix \((\mathbf{S})\)

  • gt_perm: \((b\times n_1 \times n_2)\) ground truth permutation matrix \((\mathbf{X}^{gt})\)

  • src_ns: \((b)\) exact number of nodes in the first graph (also known as the source graph) of each instance.

  • tgt_ns: \((b)\) exact number of nodes in the second graph (also known as the target graph) of each instance.

Returns

\((1)\) averaged inner product loss

Note

We support batched instances with different numbers of nodes; therefore, src_ns and tgt_ns are required to specify the exact number of nodes of each instance in the batch.
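A sketch of how a batch with different numbers of nodes might be assembled: each instance is zero-padded to the largest \(n_1\), \(n_2\) in the batch, and its true sizes are recorded in src_ns and tgt_ns. `pad_batch` is a hypothetical helper, and zero padding is an assumption; any padding value works as long as the loss only reads the valid block of each instance.

```python
from typing import List, Tuple

import torch


def pad_batch(mats: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical helper: stack variable-sized (n1_k x n2_k) matrices into a
    zero-padded (b x n1 x n2) tensor and return the per-instance sizes."""
    src_ns = torch.tensor([m.shape[0] for m in mats])
    tgt_ns = torch.tensor([m.shape[1] for m in mats])
    n1, n2 = int(src_ns.max()), int(tgt_ns.max())
    out = torch.zeros(len(mats), n1, n2, dtype=mats[0].dtype)
    for k, m in enumerate(mats):
        # Copy the valid block; the rest of this instance's slice stays zero.
        out[k, : m.shape[0], : m.shape[1]] = m
    return out, src_ns, tgt_ns


# Two ground-truth permutations of different sizes: 2 x 2 and 3 x 3.
gt_a = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
gt_b = torch.eye(3)
gt_perm, src_ns, tgt_ns = pad_batch([gt_a, gt_b])
print(gt_perm.shape)                     # torch.Size([2, 3, 3])
print(src_ns.tolist(), tgt_ns.tolist())  # [2, 3] [2, 3]
```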
