Source code for src.loss_func

import torch
import torch.nn as nn
import torch.nn.functional as F
from src.lap_solvers.hungarian import hungarian
from torch import Tensor


class PermutationLoss(nn.Module):
    r"""
    Binary cross entropy loss between two permutations, also known as "permutation loss".
    Proposed by `"Wang et al. Learning Combinatorial Embedding Networks for Deep Graph Matching. ICCV 2019."
    <http://openaccess.thecvf.com/content_ICCV_2019/papers/Wang_Learning_Combinatorial_Embedding_Networks_for_Deep_Graph_Matching_ICCV_2019_paper.pdf>`_

    .. math::
        L_{perm} =- \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2}
        \left(\mathbf{X}^{gt}_{i,j} \log \mathbf{S}_{i,j} + (1-\mathbf{X}^{gt}_{i,j}) \log (1-\mathbf{S}_{i,j}) \right)

    where :math:`\mathcal{V}_1, \mathcal{V}_2` are vertex sets for two graphs.

    .. note::
        For batched input, this loss function computes the averaged loss among all instances in the batch.
    """
    def __init__(self):
        super(PermutationLoss, self).__init__()
    def forward(self, pred_dsmat: Tensor, gt_perm: Tensor, src_ns: Tensor, tgt_ns: Tensor) -> Tensor:
        r"""
        :param pred_dsmat: :math:`(b\times n_1 \times n_2)` predicted doubly-stochastic matrix :math:`(\mathbf{S})`
        :param gt_perm: :math:`(b\times n_1 \times n_2)` ground truth permutation matrix :math:`(\mathbf{X}^{gt})`
        :param src_ns: :math:`(b)` number of exact pairs in the first graph (also known as source graph).
        :param tgt_ns: :math:`(b)` number of exact pairs in the second graph (also known as target graph).
        :return: :math:`(1)` averaged permutation loss

        .. note::
            We support batched instances with different number of nodes, therefore ``src_ns`` and ``tgt_ns`` are
            required to specify the exact number of nodes of each instance in the batch.
        """
        batch_num = pred_dsmat.shape[0]

        pred_dsmat = pred_dsmat.to(dtype=torch.float32)

        try:
            assert torch.all((pred_dsmat >= 0) * (pred_dsmat <= 1))
            assert torch.all((gt_perm >= 0) * (gt_perm <= 1))
        except AssertionError as err:
            print(pred_dsmat)
            raise err

        loss = torch.tensor(0.).to(pred_dsmat.device)
        n_sum = torch.zeros_like(loss)
        for b in range(batch_num):
            batch_slice = [b, slice(src_ns[b]), slice(tgt_ns[b])]
            loss += F.binary_cross_entropy(
                pred_dsmat[batch_slice],
                gt_perm[batch_slice],
                reduction='sum')
            n_sum += src_ns[b].to(n_sum.dtype).to(pred_dsmat.device)

        return loss / n_sum

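# A minimal usage sketch for PermutationLoss. The shapes and values below are hypothetical and
# not part of the original module; they only illustrate how padded batches are handled via
# ``src_ns`` / ``tgt_ns``.
def _example_permutation_loss():
    criterion = PermutationLoss()
    # two instances padded to 4x4: instance 0 has 3 valid nodes, instance 1 has 4
    gt_perm = torch.zeros(2, 4, 4)
    gt_perm[0, :3, :3] = torch.eye(3)
    gt_perm[1] = torch.eye(4)
    src_ns = torch.tensor([3, 4])
    tgt_ns = torch.tensor([3, 4])
    scores = torch.rand(2, 4, 4, requires_grad=True)
    pred_dsmat = torch.softmax(scores, dim=-1)  # rows sum to 1, all entries in [0, 1]
    loss = criterion(pred_dsmat, gt_perm, src_ns, tgt_ns)
    loss.backward()  # gradients flow back into ``scores``
    return loss

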
class CrossEntropyLoss(nn.Module):
    r"""
    Multi-class cross entropy loss between two permutations.

    .. math::
        L_{ce} =- \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2} \left(\mathbf{X}^{gt}_{i,j} \log \mathbf{S}_{i,j}\right)

    where :math:`\mathcal{V}_1, \mathcal{V}_2` are vertex sets for two graphs.

    .. note::
        For batched input, this loss function computes the averaged loss among all instances in the batch.
    """
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
    def forward(self, pred_dsmat: Tensor, gt_perm: Tensor, src_ns: Tensor, tgt_ns: Tensor) -> Tensor:
        r"""
        :param pred_dsmat: :math:`(b\times n_1 \times n_2)` predicted doubly-stochastic matrix :math:`(\mathbf{S})`
        :param gt_perm: :math:`(b\times n_1 \times n_2)` ground truth permutation matrix :math:`(\mathbf{X}^{gt})`
        :param src_ns: :math:`(b)` number of exact pairs in the first graph (also known as source graph).
        :param tgt_ns: :math:`(b)` number of exact pairs in the second graph (also known as target graph).
        :return: :math:`(1)` averaged cross-entropy loss

        .. note::
            We support batched instances with different number of nodes, therefore ``src_ns`` and ``tgt_ns`` are
            required to specify the exact number of nodes of each instance in the batch.
        """
        batch_num = pred_dsmat.shape[0]

        pred_dsmat = pred_dsmat.to(dtype=torch.float32)

        try:
            assert torch.all((pred_dsmat >= 0) * (pred_dsmat <= 1))
            assert torch.all((gt_perm >= 0) * (gt_perm <= 1))
        except AssertionError as err:
            print(pred_dsmat)
            raise err

        loss = torch.tensor(0.).to(pred_dsmat.device)
        n_sum = torch.zeros_like(loss)
        for b in range(batch_num):
            batch_slice = [b, slice(src_ns[b]), slice(tgt_ns[b])]
            gt_index = torch.max(gt_perm[batch_slice], dim=-1).indices
            loss += F.nll_loss(
                torch.log(pred_dsmat[batch_slice]),
                gt_index,
                reduction='sum')
            n_sum += src_ns[b].to(n_sum.dtype).to(pred_dsmat.device)

        return loss / n_sum


class PermutationLossHung(nn.Module):
    r"""
    Binary cross entropy loss between two permutations with Hungarian attention. The vanilla version without
    Hungarian attention is :class:`~src.loss_func.PermutationLoss`.

    .. math::
        L_{hung} &=-\sum_{i\in\mathcal{V}_1,j\in\mathcal{V}_2}\mathbf{Z}_{ij}\left(\mathbf{X}^\text{gt}_{ij}\log \mathbf{S}_{ij}+\left(1-\mathbf{X}^{\text{gt}}_{ij}\right)\log\left(1-\mathbf{S}_{ij}\right)\right) \\
        \mathbf{Z}&=\mathrm{OR}\left(\mathrm{Hungarian}(\mathbf{S}),\mathbf{X}^\text{gt}\right)

    where :math:`\mathcal{V}_1, \mathcal{V}_2` are vertex sets for two graphs.

    Hungarian attention highlights the entries where the model makes wrong decisions after the Hungarian step
    (which is the default discretization step during inference).

    Proposed by `"Yu et al. Learning deep graph matching with channel-independent embedding and Hungarian attention.
    ICLR 2020." <https://openreview.net/forum?id=rJgBd2NYPH>`_

    .. note::
        For batched input, this loss function computes the averaged loss among all instances in the batch.

    A working example for Hungarian attention:

    .. image:: ../../images/hungarian_attention.png
    """
    def __init__(self):
        super(PermutationLossHung, self).__init__()
    def forward(self, pred_dsmat: Tensor, gt_perm: Tensor, src_ns: Tensor, tgt_ns: Tensor) -> Tensor:
        r"""
        :param pred_dsmat: :math:`(b\times n_1 \times n_2)` predicted doubly-stochastic matrix :math:`(\mathbf{S})`
        :param gt_perm: :math:`(b\times n_1 \times n_2)` ground truth permutation matrix :math:`(\mathbf{X}^{gt})`
        :param src_ns: :math:`(b)` number of exact pairs in the first graph (also known as source graph).
        :param tgt_ns: :math:`(b)` number of exact pairs in the second graph (also known as target graph).
        :return: :math:`(1)` averaged permutation loss

        .. note::
            We support batched instances with different number of nodes, therefore ``src_ns`` and ``tgt_ns`` are
            required to specify the exact number of nodes of each instance in the batch.
        """
        batch_num = pred_dsmat.shape[0]

        assert torch.all((pred_dsmat >= 0) * (pred_dsmat <= 1))
        assert torch.all((gt_perm >= 0) * (gt_perm <= 1))

        dis_pred = hungarian(pred_dsmat, src_ns, tgt_ns)
        ali_perm = dis_pred + gt_perm
        ali_perm[ali_perm > 1.0] = 1.0  # Hungarian attention mask Z = OR(Hungarian(S), X_gt)
        pred_dsmat = torch.mul(ali_perm, pred_dsmat)
        gt_perm = torch.mul(ali_perm, gt_perm)

        loss = torch.tensor(0.).to(pred_dsmat.device)
        n_sum = torch.zeros_like(loss)
        for b in range(batch_num):
            loss += F.binary_cross_entropy(
                pred_dsmat[b, :src_ns[b], :tgt_ns[b]],
                gt_perm[b, :src_ns[b], :tgt_ns[b]],
                reduction='sum')
            n_sum += src_ns[b].to(n_sum.dtype).to(pred_dsmat.device)

        return loss / n_sum

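# A minimal sketch (hypothetical values) of how the Hungarian attention mask focuses the BCE terms:
# Z = OR(Hungarian(S), X_gt) keeps only the entries that are either matched by the solver or positive
# in the ground truth; all other entries contribute zero loss.
def _example_hungarian_attention():
    pred_dsmat = torch.tensor([[[0.7, 0.2, 0.1],
                                [0.3, 0.4, 0.3],
                                [0.1, 0.3, 0.6]]])
    gt_perm = torch.eye(3).unsqueeze(0)
    src_ns = torch.tensor([3])
    tgt_ns = torch.tensor([3])
    dis_pred = hungarian(pred_dsmat, src_ns, tgt_ns)   # discrete prediction X = Hungarian(S)
    z_mask = torch.clamp(dis_pred + gt_perm, max=1.0)  # element-wise OR of X and X_gt
    loss = PermutationLossHung()(pred_dsmat, gt_perm, src_ns, tgt_ns)
    return z_mask, loss

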
class OffsetLoss(nn.Module):
    r"""
    OffsetLoss criterion computes a robust loss function based on image pixel offset.
    Proposed by `"Zanfir et al. Deep Learning of Graph Matching. CVPR 2018."
    <http://openaccess.thecvf.com/content_cvpr_2018/html/Zanfir_Deep_Learning_of_CVPR_2018_paper.html>`_

    .. math::
        \mathbf{d}_i =& \sum_{j \in V_2} \left( \mathbf{S}_{i, j} P_{2j} \right)- P_{1i} \\
        L_{off} =& \sum_{i \in V_1} \sqrt{||\mathbf{d}_i - \mathbf{d}^{gt}_i||^2 + \epsilon}

    where :math:`\mathbf{d}_i` is the displacement vector. See :class:`src.displacement_layer.Displacement` for more details.

    :param epsilon: a small number for numerical stability
    :param norm: (optional) divisor used to normalize the loss
    """
    def __init__(self, epsilon: float=1e-5, norm=None):
        super(OffsetLoss, self).__init__()
        self.epsilon = epsilon
        self.norm = norm
    def forward(self, d1: Tensor, d2: Tensor, mask: Tensor=None) -> Tensor:
        """
        :param d1: predicted displacement matrix
        :param d2: ground truth displacement matrix
        :param mask: (optional) dummy node mask
        :return: computed offset loss
        """
        # Loss = Sum(Phi(d_i - d_i^gt))
        # Phi(x) = sqrt(x^T * x + epsilon)
        if mask is None:
            mask = torch.ones_like(d1)
        x = d1 - d2
        if self.norm is not None:
            x = x / self.norm

        xtx = torch.sum(x * x * mask, dim=-1)
        phi = torch.sqrt(xtx + self.epsilon)
        loss = torch.sum(phi) / d1.shape[0]

        return loss

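# A minimal sketch with hypothetical coordinates. The displacement d_i = sum_j S_ij * P2_j - P1_i is
# formed inline here; in the full pipeline it comes from src.displacement_layer.Displacement.
def _example_offset_loss():
    P1 = torch.rand(1, 4, 2)                         # source keypoint coordinates (b x n1 x 2)
    P2 = torch.rand(1, 4, 2)                         # target keypoint coordinates (b x n2 x 2)
    S = torch.softmax(torch.rand(1, 4, 4), dim=-1)   # predicted doubly-stochastic matrix
    X_gt = torch.eye(4).unsqueeze(0)                 # ground truth permutation
    d_pred = torch.bmm(S, P2) - P1                   # predicted displacement vectors
    d_gt = torch.bmm(X_gt, P2) - P1                  # ground truth displacement vectors
    criterion = OffsetLoss(epsilon=1e-5)
    return criterion(d_pred, d_gt)

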
class FocalLoss(nn.Module):
    r"""
    Focal loss between two permutations.

    .. math::
        L_{focal} =- \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2}
        \left((1-\mathbf{S}_{i,j})^{\gamma} \mathbf{X}^{gt}_{i,j} \log \mathbf{S}_{i,j} +
        \mathbf{S}_{i,j}^{\gamma} (1-\mathbf{X}^{gt}_{i,j}) \log (1-\mathbf{S}_{i,j}) \right)

    where :math:`\mathcal{V}_1, \mathcal{V}_2` are vertex sets for two graphs, :math:`\gamma` is the focal loss
    hyperparameter.

    :param gamma: :math:`\gamma` parameter for focal loss
    :param eps: a small parameter for numerical stability

    .. note::
        For batched input, this loss function computes the averaged loss among all instances in the batch.
    """
    def __init__(self, gamma=0., eps=1e-15):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
    def forward(self, pred_dsmat: Tensor, gt_perm: Tensor, src_ns: Tensor, tgt_ns: Tensor) -> Tensor:
        r"""
        :param pred_dsmat: :math:`(b\times n_1 \times n_2)` predicted doubly-stochastic matrix :math:`(\mathbf{S})`
        :param gt_perm: :math:`(b\times n_1 \times n_2)` ground truth permutation matrix :math:`(\mathbf{X}^{gt})`
        :param src_ns: :math:`(b)` number of exact pairs in the first graph (also known as source graph).
        :param tgt_ns: :math:`(b)` number of exact pairs in the second graph (also known as target graph).
        :return: :math:`(1)` averaged focal loss

        .. note::
            We support batched instances with different number of nodes, therefore ``src_ns`` and ``tgt_ns`` are
            required to specify the exact number of nodes of each instance in the batch.
        """
        batch_num = pred_dsmat.shape[0]

        pred_dsmat = pred_dsmat.to(dtype=torch.float32)

        assert torch.all((pred_dsmat >= 0) * (pred_dsmat <= 1))
        assert torch.all((gt_perm >= 0) * (gt_perm <= 1))

        loss = torch.tensor(0.).to(pred_dsmat.device)
        n_sum = torch.zeros_like(loss)
        for b in range(batch_num):
            x = pred_dsmat[b, :src_ns[b], :tgt_ns[b]]
            y = gt_perm[b, :src_ns[b], :tgt_ns[b]]
            loss += torch.sum(
                - (1 - x) ** self.gamma * y * torch.log(x + self.eps)
                - x ** self.gamma * (1 - y) * torch.log(1 - x + self.eps)
            )
            n_sum += src_ns[b].to(n_sum.dtype).to(pred_dsmat.device)

        return loss / n_sum

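# A quick numerical check (hypothetical inputs, not part of the original module): with gamma=0 the
# focal loss reduces to the permutation (BCE) loss, up to the eps term used for numerical stability.
def _example_focal_vs_bce():
    pred_dsmat = torch.softmax(torch.rand(1, 3, 3), dim=-1)
    gt_perm = torch.eye(3).unsqueeze(0)
    ns = torch.tensor([3])
    focal = FocalLoss(gamma=0.)(pred_dsmat, gt_perm, ns, ns)
    bce = PermutationLoss()(pred_dsmat, gt_perm, ns, ns)
    return focal, bce  # the two values should be nearly identical

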
class InnerProductLoss(nn.Module):
    r"""
    Inner product loss for self-supervised problems.

    .. math::
        L_{ip} =- \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2} \left(\mathbf{X}^{gt}_{i,j} \mathbf{S}_{i,j}\right)

    where :math:`\mathcal{V}_1, \mathcal{V}_2` are vertex sets for two graphs.

    .. note::
        For batched input, this loss function computes the averaged loss among all instances in the batch.
    """
    def __init__(self):
        super(InnerProductLoss, self).__init__()
    def forward(self, pred_dsmat: Tensor, gt_perm: Tensor, src_ns: Tensor, tgt_ns: Tensor) -> Tensor:
        r"""
        :param pred_dsmat: :math:`(b\times n_1 \times n_2)` predicted doubly-stochastic matrix :math:`(\mathbf{S})`
        :param gt_perm: :math:`(b\times n_1 \times n_2)` ground truth permutation matrix :math:`(\mathbf{X}^{gt})`
        :param src_ns: :math:`(b)` number of exact pairs in the first graph (also known as source graph).
        :param tgt_ns: :math:`(b)` number of exact pairs in the second graph (also known as target graph).
        :return: :math:`(1)` averaged inner product loss

        .. note::
            We support batched instances with different number of nodes, therefore ``src_ns`` and ``tgt_ns`` are
            required to specify the exact number of nodes of each instance in the batch.
        """
        batch_num = pred_dsmat.shape[0]

        pred_dsmat = pred_dsmat.to(dtype=torch.float32)

        try:
            assert torch.all((gt_perm >= 0) * (gt_perm <= 1))
        except AssertionError as err:
            raise err

        loss = torch.tensor(0.).to(pred_dsmat.device)
        n_sum = torch.zeros_like(loss)
        for b in range(batch_num):
            batch_slice = [b, slice(src_ns[b]), slice(tgt_ns[b])]
            loss -= torch.sum(pred_dsmat[batch_slice] * gt_perm[batch_slice])
            n_sum += src_ns[b].to(n_sum.dtype).to(pred_dsmat.device)

        return loss / n_sum


class HammingLoss(torch.nn.Module):
    r"""
    Hamming loss between two permutations.

    .. math::
        L_{hamm} = \sum_{i \in \mathcal{V}_1, j \in \mathcal{V}_2}
        \left(\mathbf{X}_{i,j} (1-\mathbf{X}^{gt}_{i,j}) + (1-\mathbf{X}_{i,j}) \mathbf{X}^{gt}_{i,j}\right)

    where :math:`\mathcal{V}_1, \mathcal{V}_2` are vertex sets for two graphs.

    First adopted by `"Rolinek et al. Deep Graph Matching via Blackbox Differentiation of Combinatorial Solvers.
    ECCV 2020." <https://arxiv.org/abs/2003.11657>`_

    .. note::
        Hamming loss is defined between two discrete matrices, and discretization will in general truncate gradient.
        A workaround may be using the `blackbox differentiation technique <https://arxiv.org/abs/1912.02175>`_.
    """
    def __init__(self):
        super(HammingLoss, self).__init__()
    def forward(self, pred_perm: Tensor, gt_perm: Tensor) -> Tensor:
        r"""
        :param pred_perm: :math:`(b\times n_1 \times n_2)` predicted permutation matrix :math:`(\mathbf{X})`
        :param gt_perm: :math:`(b\times n_1 \times n_2)` ground truth permutation matrix :math:`(\mathbf{X}^{gt})`
        :return: averaged Hamming loss
        """
        errors = pred_perm * (1.0 - gt_perm) + (1.0 - pred_perm) * gt_perm
        return errors.mean(dim=0).sum()

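# A minimal sketch (hypothetical inputs): HammingLoss compares two discrete matrices, so the prediction
# is first discretized, e.g. by the Hungarian solver; gradients do not survive this step unless a
# technique such as blackbox differentiation is applied.
def _example_hamming_loss():
    pred_dsmat = torch.softmax(torch.rand(2, 4, 4), dim=-1)
    ns = torch.tensor([4, 4])
    pred_perm = hungarian(pred_dsmat, ns, ns)   # discrete 0/1 matching matrix
    gt_perm = torch.eye(4).repeat(2, 1, 1)
    return HammingLoss()(pred_perm, gt_perm)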