import torch
from torch import Tensor
[docs]def feature_align(raw_feature: Tensor, P: Tensor, ns_t: Tensor, ori_size: tuple, device=None) -> Tensor:
r"""
Perform feature align on the image feature map.
Feature align performs bi-linear interpolation on the image feature map. This operation is inspired by "ROIAlign"
in `Mask R-CNN <https://arxiv.org/abs/1703.06870>`_.
:param raw_feature: :math:`(b\times c \times w \times h)` raw feature map. :math:`b`: batch size, :math:`c`: number
of feature channels, :math:`w`: feature map width, :math:`h`: feature map height
:param P: :math:`(b\times n \times 2)` point set containing point coordinates. The coordinates are at the scale of
the original image size. :math:`n`: number of points
:param ns_t: :math:`(b)` number of exact points. We support batched instances with different number of nodes, and
``ns_t`` is required to specify the exact number of nodes of each instance in the batch.
:param ori_size: size of the original image. Since the point coordinates are in the scale of the original image
size, this parameter is required.
:param device: output device. If not specified, it will be the same as the input
:return: :math:`(b\times c \times n)` extracted feature vectors
"""
if device is None:
device = raw_feature.device
batch_num = raw_feature.shape[0]
channel_num = raw_feature.shape[1]
n_max = P.shape[1]
ori_size = torch.tensor(ori_size, dtype=torch.float32, device=device)
F = torch.zeros(batch_num, channel_num, n_max, dtype=torch.float32, device=device)
for idx, feature in enumerate(raw_feature):
n = ns_t[idx]
feat_size = torch.as_tensor(feature.shape[1:3], dtype=torch.float32, device=device)
_P = P[idx, 0:n]
interp_2d(feature, _P, ori_size, feat_size, out=F[idx, :, 0:n])
return F
[docs]def interp_2d(z: Tensor, P: Tensor, ori_size: Tensor, feat_size: Tensor, out=None, device=None) -> Tensor:
r"""
Interpolate in 2d grid space. z can be 3-dimensional where the first dimension is feature dimension.
:param z: :math:`(c\times w\times h)` feature map. :math:`c`: number of feature channels, :math:`w`: feature map
width, :math:`h`: feature map height
:param P: :math:`(n\times 2)` point set containing point coordinates. The coordinates are at the scale of
the original image size. :math:`n`: number of points
:param ori_size: :math:`(2)` size of the original image
:param feat_size: :math:`(2)` size of the feature map
:param out: optional output tensor
:param device: output device. If not specified, it will be the same as the input
:return: :math:`(c \times n)` extracted feature vectors
"""
if device is None:
device = z.device
step = ori_size / feat_size
if out is None:
out = torch.zeros(z.shape[0], P.shape[0], dtype=torch.float32, device=device)
for i, p in enumerate(P):
p = (p - step / 2) / ori_size * feat_size
out[:, i] = bilinear_interpolate(z, p[0], p[1])
return out
[docs]def bilinear_interpolate(im: Tensor, x: Tensor, y: Tensor, device=None):
r"""
Bi-linear interpolate 3d feature map to 2d coordinate (x, y).
The coordinates are at the same scale of :math:`w\times h`.
:param im: :math:`(c\times w\times h)` feature map
:param x: :math:`(1)` x coordinate
:param y: :math:`(1)` y coordinate
:param device: output device. If not specified, it will be the same as the input
:return: :math:`(c)` interpolated feature vector
"""
if device is None:
device = im.device
x = x.to(torch.float32).to(device)
y = y.to(torch.float32).to(device)
x0 = torch.floor(x)
x1 = x0 + 1
y0 = torch.floor(y)
y1 = y0 + 1
x0 = torch.clamp(x0, 0, im.shape[2] - 1)
x1 = torch.clamp(x1, 0, im.shape[2] - 1)
y0 = torch.clamp(y0, 0, im.shape[1] - 1)
y1 = torch.clamp(y1, 0, im.shape[1] - 1)
x0 = x0.to(torch.int32).to(device)
x1 = x1.to(torch.int32).to(device)
y0 = y0.to(torch.int32).to(device)
y1 = y1.to(torch.int32).to(device)
Ia = im[:, y0, x0]
Ib = im[:, y1, x0]
Ic = im[:, y0, x1]
Id = im[:, y1, x1]
# to perform nearest neighbor interpolation if out of bounds
if x0 == x1:
if x0 == 0:
x0 -= 1
else:
x1 += 1
if y0 == y1:
if y0 == 0:
y0 -= 1
else:
y1 += 1
x0 = x0.to(torch.float32).to(device)
x1 = x1.to(torch.float32).to(device)
y0 = y0.to(torch.float32).to(device)
y1 = y1.to(torch.float32).to(device)
wa = (x1 - x) * (y1 - y)
wb = (x1 - x) * (y - y0)
wc = (x - x0) * (y1 - y)
wd = (x - x0) * (y - y0)
out = Ia * wa + Ib * wb + Ic * wc + Id * wd
return out