InnerProductLoss
- class src.loss_func.InnerProductLoss
Inner product loss for self-supervised problems.
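\[L_{ip} = - \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2} \left(\mathbf{X}^{gt}_{i,j} \mathbf{S}_{i,j}\right)\]
where \(\mathcal{V}_1, \mathcal{V}_2\) are the vertex sets of the two graphs, \(\mathbf{S}\) is the predicted doubly-stochastic matrix, and \(\mathbf{X}^{gt}\) is the ground-truth permutation matrix.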
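As a minimal worked example of the formula on a single instance (the 3x3 matrices below are made-up values chosen only for illustration), only the entries of \(\mathbf{S}\) at ground-truth matches contribute, so minimizing the loss pushes probability mass onto the correct correspondences. Because every entry of a doubly-stochastic matrix is at most 1, the loss for an \(n\)-node instance can be no lower than \(-n\), reached when \(\mathbf{S} = \mathbf{X}^{gt}\).

```python
import torch

# Made-up predicted doubly-stochastic matrix S (every row and column sums to 1).
S = torch.tensor([[0.7, 0.2, 0.1],
                  [0.2, 0.6, 0.2],
                  [0.1, 0.2, 0.7]])

# Made-up ground-truth permutation X^gt: node 0 -> 0, node 1 -> 2, node 2 -> 1.
X_gt = torch.tensor([[1.0, 0.0, 0.0],
                     [0.0, 0.0, 1.0],
                     [0.0, 1.0, 0.0]])

# Negative inner product: -sum_{i,j} X^gt_{ij} * S_{ij}.
loss = -(X_gt * S).sum()
print(loss.item())  # -(0.7 + 0.2 + 0.2), roughly -1.1
```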
Note
For batched input, this loss function returns the loss averaged over all instances in the batch.
- forward(pred_dsmat: torch.Tensor, gt_perm: torch.Tensor, src_ns: torch.Tensor, tgt_ns: torch.Tensor) → torch.Tensor
- Parameters
pred_dsmat – \((b\times n_1 \times n_2)\) predicted doubly-stochastic matrix \((\mathbf{S})\)
gt_perm – \((b\times n_1 \times n_2)\) ground truth permutation matrix \((\mathbf{X}^{gt})\)
src_ns – \((b)\) exact number of nodes in the first graph (also known as the source graph) of each instance.
tgt_ns – \((b)\) exact number of nodes in the second graph (also known as the target graph) of each instance.
- Returns
\((1)\) averaged inner product loss
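To illustrate the expected shapes, the sketch below packs per-instance ground-truth matchings into a zero-padded \((b\times n_1 \times n_2)\) tensor together with the two \((b)\) node-count tensors. The helper build_gt_batch and its input format are hypothetical and not part of src.loss_func; pred_dsmat would need the same padded shape as gt_perm.

```python
import torch


def build_gt_batch(matchings, sizes):
    """Hypothetical helper: pack per-instance matchings into padded tensors.

    matchings[b][i] is the target node matched to source node i of instance b.
    sizes[b] is the (n1, n2) pair of real node counts for instance b.
    """
    b = len(sizes)
    n1_max = max(n1 for n1, _ in sizes)
    n2_max = max(n2 for _, n2 in sizes)
    # Entries outside each instance's real n1 x n2 block stay zero (padding).
    gt_perm = torch.zeros(b, n1_max, n2_max)
    for k, match in enumerate(matchings):
        for i, j in enumerate(match):
            gt_perm[k, i, j] = 1.0
    src_ns = torch.tensor([n1 for n1, _ in sizes])
    tgt_ns = torch.tensor([n2 for _, n2 in sizes])
    return gt_perm, src_ns, tgt_ns


# Instance 0 has 2 nodes per graph, instance 1 has 3 nodes per graph.
gt_perm, src_ns, tgt_ns = build_gt_batch([[1, 0], [0, 2, 1]], [(2, 2), (3, 3)])
print(gt_perm.shape)                     # torch.Size([2, 3, 3])
print(src_ns.tolist(), tgt_ns.tolist())  # [2, 3] [2, 3]
```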
Note
We support batched instances with different numbers of nodes; therefore, src_ns and tgt_ns are required to specify the exact number of nodes of each instance in the batch.
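A minimal, vectorized sketch of the computation, assuming that padding is excluded by masking each instance to its top-left src_ns[b] x tgt_ns[b] block and that the batch loss is the plain mean of the per-instance losses. The mean over instances is only our reading of the note on averaging; the library may normalize differently (for example, by the total number of nodes). The function name inner_product_loss is hypothetical; the class itself is used through its forward method.

```python
import torch


def inner_product_loss(pred_dsmat, gt_perm, src_ns, tgt_ns):
    """Hypothetical sketch: masked negative inner product, averaged over the batch."""
    device = pred_dsmat.device
    n1_max, n2_max = pred_dsmat.shape[1], pred_dsmat.shape[2]
    # rows[b, i, 0] is True when source node i is real in instance b.
    rows = torch.arange(n1_max, device=device).view(1, -1, 1) < src_ns.to(device).view(-1, 1, 1)
    # cols[b, 0, j] is True when target node j is real in instance b.
    cols = torch.arange(n2_max, device=device).view(1, 1, -1) < tgt_ns.to(device).view(-1, 1, 1)
    mask = (rows & cols).to(pred_dsmat.dtype)  # (b, n1_max, n2_max)
    # Per-instance loss: -sum over the valid block of X^gt * S.
    per_instance = -(pred_dsmat * gt_perm.to(pred_dsmat.dtype) * mask).sum(dim=(1, 2))
    # Assumption: average over instances, as described in the note above.
    return per_instance.mean()


# Instance 0 has 2 nodes and is zero-padded to 3x3; instance 1 has 3 nodes.
pred = torch.zeros(2, 3, 3)
pred[0, :2, :2] = torch.tensor([[0.9, 0.1], [0.1, 0.9]])
pred[1] = torch.full((3, 3), 1.0 / 3)
gt = torch.zeros(2, 3, 3)
gt[0, :2, :2] = torch.eye(2)
gt[1] = torch.eye(3)
src_ns = torch.tensor([2, 3])
tgt_ns = torch.tensor([2, 3])

loss = inner_product_loss(pred, gt, src_ns, tgt_ns)
print(loss.item())  # mean of -1.8 and -1.0, roughly -1.4
```

Masking makes the result independent of whatever values happen to sit in the padded region of either tensor, which is why the exact node counts must be passed in.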
- training: bool