Source code for src.feature_align

import torch
from torch import Tensor


[docs]def feature_align(raw_feature: Tensor, P: Tensor, ns_t: Tensor, ori_size: tuple, device=None) -> Tensor: r""" Perform feature align on the image feature map. Feature align performs bi-linear interpolation on the image feature map. This operation is inspired by "ROIAlign" in `Mask R-CNN <https://arxiv.org/abs/1703.06870>`_. :param raw_feature: :math:`(b\times c \times w \times h)` raw feature map. :math:`b`: batch size, :math:`c`: number of feature channels, :math:`w`: feature map width, :math:`h`: feature map height :param P: :math:`(b\times n \times 2)` point set containing point coordinates. The coordinates are at the scale of the original image size. :math:`n`: number of points :param ns_t: :math:`(b)` number of exact points. We support batched instances with different number of nodes, and ``ns_t`` is required to specify the exact number of nodes of each instance in the batch. :param ori_size: size of the original image. Since the point coordinates are in the scale of the original image size, this parameter is required. :param device: output device. If not specified, it will be the same as the input :return: :math:`(b\times c \times n)` extracted feature vectors """ if device is None: device = raw_feature.device batch_num = raw_feature.shape[0] channel_num = raw_feature.shape[1] n_max = P.shape[1] ori_size = torch.tensor(ori_size, dtype=torch.float32, device=device) F = torch.zeros(batch_num, channel_num, n_max, dtype=torch.float32, device=device) for idx, feature in enumerate(raw_feature): n = ns_t[idx] feat_size = torch.as_tensor(feature.shape[1:3], dtype=torch.float32, device=device) _P = P[idx, 0:n] interp_2d(feature, _P, ori_size, feat_size, out=F[idx, :, 0:n]) return F
[docs]def interp_2d(z: Tensor, P: Tensor, ori_size: Tensor, feat_size: Tensor, out=None, device=None) -> Tensor: r""" Interpolate in 2d grid space. z can be 3-dimensional where the first dimension is feature dimension. :param z: :math:`(c\times w\times h)` feature map. :math:`c`: number of feature channels, :math:`w`: feature map width, :math:`h`: feature map height :param P: :math:`(n\times 2)` point set containing point coordinates. The coordinates are at the scale of the original image size. :math:`n`: number of points :param ori_size: :math:`(2)` size of the original image :param feat_size: :math:`(2)` size of the feature map :param out: optional output tensor :param device: output device. If not specified, it will be the same as the input :return: :math:`(c \times n)` extracted feature vectors """ if device is None: device = z.device step = ori_size / feat_size if out is None: out = torch.zeros(z.shape[0], P.shape[0], dtype=torch.float32, device=device) for i, p in enumerate(P): p = (p - step / 2) / ori_size * feat_size out[:, i] = bilinear_interpolate(z, p[0], p[1]) return out
[docs]def bilinear_interpolate(im: Tensor, x: Tensor, y: Tensor, device=None): r""" Bi-linear interpolate 3d feature map to 2d coordinate (x, y). The coordinates are at the same scale of :math:`w\times h`. :param im: :math:`(c\times w\times h)` feature map :param x: :math:`(1)` x coordinate :param y: :math:`(1)` y coordinate :param device: output device. If not specified, it will be the same as the input :return: :math:`(c)` interpolated feature vector """ if device is None: device = im.device x = x.to(torch.float32).to(device) y = y.to(torch.float32).to(device) x0 = torch.floor(x) x1 = x0 + 1 y0 = torch.floor(y) y1 = y0 + 1 x0 = torch.clamp(x0, 0, im.shape[2] - 1) x1 = torch.clamp(x1, 0, im.shape[2] - 1) y0 = torch.clamp(y0, 0, im.shape[1] - 1) y1 = torch.clamp(y1, 0, im.shape[1] - 1) x0 = x0.to(torch.int32).to(device) x1 = x1.to(torch.int32).to(device) y0 = y0.to(torch.int32).to(device) y1 = y1.to(torch.int32).to(device) Ia = im[:, y0, x0] Ib = im[:, y1, x0] Ic = im[:, y0, x1] Id = im[:, y1, x1] # to perform nearest neighbor interpolation if out of bounds if x0 == x1: if x0 == 0: x0 -= 1 else: x1 += 1 if y0 == y1: if y0 == 0: y0 -= 1 else: y1 += 1 x0 = x0.to(torch.float32).to(device) x1 = x1.to(torch.float32).to(device) y0 = y0.to(torch.float32).to(device) y1 = y1.to(torch.float32).to(device) wa = (x1 - x) * (y1 - y) wb = (x1 - x) * (y - y0) wc = (x - x0) * (y1 - y) wd = (x - x0) * (y - y0) out = Ia * wa + Ib * wb + Ic * wc + Id * wd return out