FocalLoss

class src.loss_func.FocalLoss(gamma=0.0, eps=1e-15)

Focal loss between two permutations.

\[L_{focal} =- \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2} \left((1-\mathbf{S}_{i,j})^{\gamma} \mathbf{X}^{gt}_{i,j} \log \mathbf{S}_{i,j} + \mathbf{S}_{i,j}^{\gamma} (1-\mathbf{X}^{gt}_{i,j}) \log (1-\mathbf{S}_{i,j}) \right)\]

where \(\mathcal{V}_1, \mathcal{V}_2\) are the vertex sets of the two graphs and \(\gamma\) is the focal loss hyperparameter.
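The formula can be checked by hand on a tiny instance. The following is a minimal, dependency-free sketch, not the library's code: the function name `focal_loss_single` is hypothetical, and adding `eps` inside the logarithms is an assumption about how the stability term is applied.

```python
import math

def focal_loss_single(S, X_gt, gamma=0.0, eps=1e-15):
    """Sum the focal-loss terms over all (i, j) pairs of one instance.

    Hypothetical helper for illustration; eps inside log() is an assumption.
    """
    total = 0.0
    for s_row, x_row in zip(S, X_gt):
        for s, x in zip(s_row, x_row):
            # Term for matched pairs (x = 1), scaled by (1 - s)^gamma.
            pos = (1.0 - s) ** gamma * x * math.log(s + eps)
            # Term for unmatched pairs (x = 0), scaled by s^gamma.
            neg = s ** gamma * (1.0 - x) * math.log(1.0 - s + eps)
            total -= pos + neg
    return total

# A confident, correct 2x2 prediction against the identity permutation.
S = [[0.9, 0.1], [0.1, 0.9]]
X_gt = [[1, 0], [0, 1]]

ce = focal_loss_single(S, X_gt, gamma=0.0)  # no modulation
fl = focal_loss_single(S, X_gt, gamma=2.0)  # every term scaled by 0.1 ** 2
```

In this example every entry is already well classified, so raising gamma from 0 to 2 scales each term by 0.01 and the total shrinks by the same factor.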

Parameters
  • gamma – \(\gamma\) parameter for focal loss

  • eps – a small parameter for numerical stability
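With the default gamma=0.0, both modulating factors \((1-\mathbf{S}_{i,j})^{\gamma}\) and \(\mathbf{S}_{i,j}^{\gamma}\) equal 1, so the formula above reduces to the standard binary cross-entropy between \(\mathbf{S}\) and \(\mathbf{X}^{gt}\):

\[L_{focal}\Big|_{\gamma=0} =- \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2} \left(\mathbf{X}^{gt}_{i,j} \log \mathbf{S}_{i,j} + (1-\mathbf{X}^{gt}_{i,j}) \log (1-\mathbf{S}_{i,j}) \right)\]

Larger values of \(\gamma\) shrink the contribution of entries whose prediction is already close to the ground truth.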

Note

For batched input, this loss function returns the loss averaged over all instances in the batch.

forward(pred_dsmat: torch.Tensor, gt_perm: torch.Tensor, src_ns: torch.Tensor, tgt_ns: torch.Tensor) → torch.Tensor
Parameters
  • pred_dsmat – \((b\times n_1 \times n_2)\) predicted doubly-stochastic matrix \((\mathbf{S})\)

  • gt_perm – \((b\times n_1 \times n_2)\) ground truth permutation matrix \((\mathbf{X}^{gt})\)

  • src_ns – \((b)\) exact number of nodes in the first graph (also known as the source graph) of each instance.

  • tgt_ns – \((b)\) exact number of nodes in the second graph (also known as the target graph) of each instance.

Returns

\((1)\) averaged focal loss

Note

Batched instances may have different numbers of nodes. src_ns and tgt_ns are therefore required: they give the exact number of nodes of each instance in the batch, so that padded rows and columns can be told apart from real ones.
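To illustrate why src_ns and tgt_ns matter, here is a hedged sketch of a batched focal loss that crops each padded matrix to its true size before summing. The function `focal_loss_batched` is a hypothetical re-implementation, not the library's code. Two details are assumptions: `eps` is added inside the logarithms, and the per-instance sums are averaged with a plain mean over the batch (the library may normalize differently).

```python
import torch

def focal_loss_batched(pred_dsmat, gt_perm, src_ns, tgt_ns, gamma=0.0, eps=1e-15):
    """Hypothetical sketch: crop each instance to (n1, n2), then average."""
    per_instance = []
    for b in range(pred_dsmat.shape[0]):
        n1, n2 = int(src_ns[b]), int(tgt_ns[b])
        # Keep only the real nodes; padded rows and columns are ignored.
        s = pred_dsmat[b, :n1, :n2]
        x = gt_perm[b, :n1, :n2]
        pos = (1 - s) ** gamma * x * torch.log(s + eps)
        neg = s ** gamma * (1 - x) * torch.log(1 - s + eps)
        per_instance.append(-(pos + neg).sum())
    # Assumption: plain mean over instances.
    return torch.stack(per_instance).mean()

# Batch of two instances padded to 3x3: the first is 3x3, the second 2x2.
pred = torch.full((2, 3, 3), 1.0 / 3.0)
pred[1] = 0.0
pred[1, :2, :2] = 0.5

gt = torch.zeros(2, 3, 3)
gt[0] = torch.eye(3)
gt[1, :2, :2] = torch.eye(2)

src_ns = torch.tensor([3, 2])
tgt_ns = torch.tensor([3, 2])

loss = focal_loss_batched(pred, gt, src_ns, tgt_ns, gamma=0.0)
```

Because of the cropping, whatever values sit in the padded third row and column of the second instance have no effect on the result.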

training: bool