FocalLoss
- class src.loss_func.FocalLoss(gamma=0.0, eps=1e-15)
Focal loss between a predicted doubly-stochastic matrix and a ground-truth permutation matrix.
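\[L_{focal} = - \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2} \left((1-\mathbf{S}_{i,j})^{\gamma} \mathbf{X}^{gt}_{i,j} \log \mathbf{S}_{i,j} + \mathbf{S}_{i,j}^{\gamma} (1-\mathbf{X}^{gt}_{i,j}) \log (1-\mathbf{S}_{i,j}) \right)\]
where \(\mathcal{V}_1, \mathcal{V}_2\) are the vertex sets of the two graphs, \(\mathbf{S}\) is the predicted doubly-stochastic matrix, \(\mathbf{X}^{gt}\) is the ground-truth permutation matrix, and \(\gamma\) is the focal loss hyperparameter. With \(\gamma = 0\), the loss reduces to the binary cross-entropy between \(\mathbf{S}\) and \(\mathbf{X}^{gt}\).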
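As a plain-Python illustration of the formula above for a single, unbatched instance, the sketch below sums both terms over every pair \((i, j)\). It is not the library's implementation: the helper name `focal_loss_single` is hypothetical, and adding `eps` inside each logarithm is an assumption about how the stability constant is used.

```python
import math
from typing import List

Matrix = List[List[float]]


def focal_loss_single(s: Matrix, x_gt: Matrix, gamma: float = 0.0, eps: float = 1e-15) -> float:
    """Hypothetical helper: focal loss for one instance, summed over all pairs (i, j)."""
    total = 0.0
    for s_row, x_row in zip(s, x_gt):
        for s_ij, x_ij in zip(s_row, x_row):
            # Matched-pair term, scaled down by (1 - S_ij)^gamma when S_ij is already high.
            pos = (1.0 - s_ij) ** gamma * x_ij * math.log(s_ij + eps)
            # Unmatched-pair term, scaled down by S_ij^gamma when S_ij is already low.
            neg = s_ij ** gamma * (1.0 - x_ij) * math.log(1.0 - s_ij + eps)
            # Leading minus sign of L_focal.
            total -= pos + neg
    return total
```

For example, `focal_loss_single([[0.9, 0.1], [0.2, 0.8]], [[1, 0], [0, 1]], gamma=2.0)` is smaller than the same call with `gamma=0.0`, because every term is multiplied by a factor below one.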
- Parameters
gamma – the focusing parameter \(\gamma\) of the focal loss (default 0.0)
eps – a small constant for numerical stability (default 1e-15)
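To illustrate the effect of gamma, consider a single matched pair (\(\mathbf{X}^{gt}_{i,j} = 1\)), so only the first term of the formula is active. Ignoring eps and using the natural logarithm, a confident prediction and a poor one are weighted as follows:
\[
\begin{aligned}
\mathbf{S}_{i,j} = 0.9:&\quad \gamma = 0:\ -\log 0.9 \approx 0.105, &\quad \gamma = 2:\ (1-0.9)^{2}\,(-\log 0.9) \approx 0.00105\\
\mathbf{S}_{i,j} = 0.1:&\quad \gamma = 0:\ -\log 0.1 \approx 2.303, &\quad \gamma = 2:\ (1-0.1)^{2}\,(-\log 0.1) \approx 1.865
\end{aligned}
\]
With \(\gamma = 2\), the contribution of the already-confident pair shrinks by a factor of 100, while the poorly predicted pair keeps about 81% of its weight, so the loss concentrates on hard pairs.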
Note
For batched input, this loss function returns the loss averaged over all instances in the batch.
- forward(pred_dsmat: torch.Tensor, gt_perm: torch.Tensor, src_ns: torch.Tensor, tgt_ns: torch.Tensor) → torch.Tensor
- Parameters
pred_dsmat – \((b\times n_1 \times n_2)\) predicted doubly-stochastic matrix \((\mathbf{S})\)
gt_perm – \((b\times n_1 \times n_2)\) ground truth permutation matrix \((\mathbf{X}^{gt})\)
src_ns – \((b)\) exact number of nodes in the first graph (also known as the source graph) of each instance.
tgt_ns – \((b)\) exact number of nodes in the second graph (also known as the target graph) of each instance.
- Returns
\((1)\) averaged focal loss
Note
We support batched instances with different numbers of nodes; therefore, src_ns and tgt_ns are required to specify the exact number of nodes of each instance in the batch.
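The per-instance masking described in this note can be sketched in PyTorch as below. This is a minimal sketch under stated assumptions, not the library's source: the function name `focal_loss_sketch` is hypothetical, eps is assumed to be added inside each logarithm, and the summed loss is divided by the batch size to follow the "averaged over all instances" wording; the actual implementation may normalize differently (for example by the total number of nodes).

```python
import torch


def focal_loss_sketch(pred_dsmat: torch.Tensor, gt_perm: torch.Tensor,
                      src_ns: torch.Tensor, tgt_ns: torch.Tensor,
                      gamma: float = 0.0, eps: float = 1e-15) -> torch.Tensor:
    """Hypothetical re-implementation of the documented focal loss for padded batches."""
    batch = pred_dsmat.shape[0]
    total = torch.zeros((), dtype=pred_dsmat.dtype, device=pred_dsmat.device)
    for b in range(batch):
        n1, n2 = int(src_ns[b]), int(tgt_ns[b])
        # Keep only the valid n1 x n2 block; everything outside it is padding.
        s = pred_dsmat[b, :n1, :n2]
        x = gt_perm[b, :n1, :n2].to(s.dtype)
        # Assumption: eps sits inside each log to avoid log(0).
        pos = (1 - s) ** gamma * x * torch.log(s + eps)
        neg = s ** gamma * (1 - x) * torch.log(1 - s + eps)
        total = total - torch.sum(pos + neg)
    # Assumption: average over the instances in the batch.
    return total / batch


# Usage: two instances with 2 and 3 nodes, padded to 3 x 3.
pred = torch.full((2, 3, 3), 0.5)
gt = torch.zeros(2, 3, 3)
gt[0, :2, :2] = torch.eye(2)
gt[1] = torch.eye(3)
loss = focal_loss_sketch(pred, gt, torch.tensor([2, 3]), torch.tensor([2, 3]), gamma=2.0)
```

Because each instance is sliced to its own `src_ns[b]` × `tgt_ns[b]` block before the loss is computed, values in the padded region never contribute to the result.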
- training: bool