FocalLoss

class src.loss_func.FocalLoss(gamma=0.0, eps=1e-15)

Focal loss between a predicted doubly-stochastic matrix and a ground truth permutation matrix.

$$L_{focal} = - \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2} \left((1-\mathbf{S}_{i,j})^{\gamma} \mathbf{X}^{gt}_{i,j} \log \mathbf{S}_{i,j} + \mathbf{S}_{i,j}^{\gamma} (1-\mathbf{X}^{gt}_{i,j}) \log (1-\mathbf{S}_{i,j}) \right)$$

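where $$\mathcal{V}_1, \mathcal{V}_2$$ are the vertex sets of the two graphs, $$\mathbf{S}$$ is the predicted doubly-stochastic matrix, $$\mathbf{X}^{gt}$$ is the ground truth permutation matrix, and $$\gamma$$ is the focal loss hyperparameter.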
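The formula can be transcribed almost term by term for a single, unpadded instance. The sketch below is illustrative only: `focal_loss_single` is a hypothetical name, and adding `eps` inside the logarithms is an assumption about how the `eps` parameter is used.

```python
import torch


def focal_loss_single(S: torch.Tensor, X_gt: torch.Tensor,
                      gamma: float = 0.0, eps: float = 1e-15) -> torch.Tensor:
    """Focal loss for one (n1 x n2) instance, summed over all (i, j) pairs.

    S    -- predicted doubly-stochastic matrix, entries in [0, 1]
    X_gt -- ground truth permutation matrix, entries in {0, 1}
    eps  -- assumed to be added inside the logs so that log(0) never occurs
    """
    # Matched pairs (X_gt = 1): modulated by (1 - S)^gamma
    pos = (1 - S) ** gamma * X_gt * torch.log(S + eps)
    # Unmatched pairs (X_gt = 0): modulated by S^gamma
    neg = S ** gamma * (1 - X_gt) * torch.log(1 - S + eps)
    return -(pos + neg).sum()


S = torch.tensor([[0.9, 0.1],
                  [0.1, 0.9]])
X_gt = torch.eye(2)
print(focal_loss_single(S, X_gt, gamma=0.0))
print(focal_loss_single(S, X_gt, gamma=2.0))
```

For this nearly correct prediction, every modulating factor equals $$0.1^2 = 0.01$$ when $$\gamma = 2$$, so the loss drops by a factor of 100 compared with $$\gamma = 0$$. Entries that are already predicted correctly with high confidence contribute little once $$\gamma > 0$$.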

Parameters
• gamma – $$\gamma$$ parameter for focal loss

• eps – a small parameter for numerical stability
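With $$\gamma = 0$$ both modulating factors, $$(1-\mathbf{S}_{i,j})^{\gamma}$$ and $$\mathbf{S}_{i,j}^{\gamma}$$, equal 1, so the focal loss reduces to binary cross-entropy summed over all pairs. The check below is a sketch of that identity; `focal_terms` is a hypothetical helper, the eps placement is an assumption, and the random `S` is not doubly stochastic (the identity only needs entries in $$(0, 1)$$).

```python
import torch
import torch.nn.functional as F


def focal_terms(S, X_gt, gamma, eps=1e-15):
    # Elementwise focal loss terms from the formula above
    # (placing eps inside the logarithms is an assumption)
    return -((1 - S) ** gamma * X_gt * torch.log(S + eps)
             + S ** gamma * (1 - X_gt) * torch.log(1 - S + eps))


torch.manual_seed(0)
S = torch.rand(4, 4) * 0.98 + 0.01       # entries kept strictly inside (0, 1)
X_gt = torch.eye(4)[torch.randperm(4)]   # random ground truth permutation matrix

# With gamma = 0 the focal loss equals summed binary cross-entropy
bce = F.binary_cross_entropy(S, X_gt, reduction="sum")
print(torch.allclose(focal_terms(S, X_gt, 0.0).sum(), bce, atol=1e-4))
```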

Note

For batched input, this loss function computes the averaged loss among all instances in the batch.

forward(pred_dsmat: torch.Tensor, gt_perm: torch.Tensor, src_ns: torch.Tensor, tgt_ns: torch.Tensor) → torch.Tensor
Parameters
• pred_dsmat – $$(b\times n_1 \times n_2)$$ predicted doubly-stochastic matrix $$(\mathbf{S})$$

• gt_perm – $$(b\times n_1 \times n_2)$$ ground truth permutation matrix $$(\mathbf{X}^{gt})$$

• src_ns – $$(b)$$ exact number of nodes in the first graph (also known as the source graph) of each instance

• tgt_ns – $$(b)$$ exact number of nodes in the second graph (also known as the target graph) of each instance
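One way to build inputs with these shapes from instances of different sizes is to zero-pad every instance to the largest size in the batch and record the true sizes in `src_ns` and `tgt_ns`. The sketch below assumes this padding scheme; `make_batch` is a hypothetical helper, not part of `src.loss_func`.

```python
import torch


def make_batch(instances):
    """Pad a list of (S, X_gt) pairs into batched tensors.

    Hypothetical helper for illustration: zero-pads every instance to the
    largest n1 x n2 in the batch and records the true sizes.
    """
    n1_max = max(S.shape[0] for S, _ in instances)
    n2_max = max(S.shape[1] for S, _ in instances)
    b = len(instances)
    pred_dsmat = torch.zeros(b, n1_max, n2_max)
    gt_perm = torch.zeros(b, n1_max, n2_max)
    src_ns = torch.zeros(b, dtype=torch.long)
    tgt_ns = torch.zeros(b, dtype=torch.long)
    for k, (S, X_gt) in enumerate(instances):
        n1, n2 = S.shape
        pred_dsmat[k, :n1, :n2] = S
        gt_perm[k, :n1, :n2] = X_gt
        src_ns[k], tgt_ns[k] = n1, n2
    return pred_dsmat, gt_perm, src_ns, tgt_ns


# Instance 0 has 3 nodes per graph, instance 1 has 2 nodes per graph
S0 = torch.full((3, 3), 1 / 3)            # uniform doubly-stochastic matrix
X0 = torch.eye(3)[[2, 0, 1]]              # row i has its 1 in column [2, 0, 1][i]
S1 = torch.tensor([[0.8, 0.2], [0.2, 0.8]])
X1 = torch.eye(2)

pred_dsmat, gt_perm, src_ns, tgt_ns = make_batch([(S0, X0), (S1, X1)])
print(pred_dsmat.shape, src_ns.tolist(), tgt_ns.tolist())
```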

Returns

$$(1)$$ averaged focal loss

Note

Batched instances may have different numbers of nodes, so src_ns and tgt_ns are required to specify the exact number of nodes of each instance in the batch.
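The cropping that `src_ns` and `tgt_ns` make possible can be sketched as a loop over the batch that keeps only the exact $$n_1 \times n_2$$ block of each instance. `FocalLossSketch` is a hypothetical illustration, not the documented class. In particular, the normalization (dividing the summed loss by the total number of source nodes in the batch) is an assumption, since the text above only states that the loss is averaged.

```python
import torch
from torch import nn


class FocalLossSketch(nn.Module):
    """Illustrative re-implementation, not the library's code.

    Assumptions: eps is added inside the logs, and the summed loss is
    divided by the total number of source nodes in the batch.
    """

    def __init__(self, gamma: float = 0.0, eps: float = 1e-15):
        super().__init__()
        self.gamma = gamma
        self.eps = eps

    def forward(self, pred_dsmat, gt_perm, src_ns, tgt_ns):
        loss = pred_dsmat.new_zeros(())
        n_sum = pred_dsmat.new_zeros(())
        for b in range(pred_dsmat.shape[0]):
            # Crop away padding so only the exact n1 x n2 block contributes
            x = pred_dsmat[b, :src_ns[b], :tgt_ns[b]]
            y = gt_perm[b, :src_ns[b], :tgt_ns[b]]
            loss = loss - torch.sum(
                (1 - x) ** self.gamma * y * torch.log(x + self.eps)
                + x ** self.gamma * (1 - y) * torch.log(1 - x + self.eps)
            )
            n_sum = n_sum + src_ns[b]
        return loss / n_sum


S = torch.tensor([[0.9, 0.1], [0.1, 0.9]])
pred = torch.full((2, 3, 3), 0.5)   # padded entries deliberately non-zero
pred[:, :2, :2] = S                 # both instances share the same 2x2 block
gt = torch.zeros(2, 3, 3)
gt[:, :2, :2] = torch.eye(2)
ns = torch.tensor([2, 2])

crit = FocalLossSketch(gamma=0.0)
print(crit(pred, gt, ns, ns))       # ≈ 0.2107: 2 instances * 4 * -log(0.9), divided by 4 source nodes
```

Because each instance is cropped before the sum, the 0.5 values in the padded row and column never reach the loss.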
