import torch
import torch.nn as nn
from src.utils.sparse import sbmm
class SpectralMatching(nn.Module):
    """
    Spectral Graph Matching solver.

    Also known as a Power Iteration layer: it approximates the leading
    eigenvector of the input matrix by repeating, for every iteration,

        v_{k+1} = M v_k / ||M v_k||_2

    Iteration stops after ``max_iter`` steps, or earlier once the L2 norm of
    the change between two consecutive iterates drops below ``stop_thresh``.
    That norm is taken over the whole batch at once, not per batch element.

    :param max_iter: maximum number of power iterations
    :param stop_thresh: early-stopping threshold on ||v_k - v_{k-1}||
    """

    def __init__(self, max_iter=50, stop_thresh=2e-7):
        super(SpectralMatching, self).__init__()
        self.max_iter = max_iter
        self.stop_thresh = stop_thresh

    def forward(self, M, v0=None, **kwargs):
        """
        Run power iteration on a batch of square matrices.

        :param M: batched input matrix of shape (b, n, n); may be dense or
                  sparse (sparse inputs are multiplied with ``sbmm``)
        :param v0: optional initial vector of shape (b, n, 1) or (b, n).
                   If not given, it is initialised with all ones.
        :param kwargs: unused; accepted and ignored
        :return: the computed eigenvector estimate, shape (b, n), with each
                 batch element normalised to unit L2 norm (unless
                 ``max_iter`` is 0, in which case ``v0`` is returned flattened)
        """
        batch_num = M.shape[0]
        mn = M.shape[1]
        if v0 is None:
            v0 = torch.ones(batch_num, mn, 1, dtype=M.dtype, device=M.device)
        else:
            # Accept both column-vector (b, n, 1) and flat (b, n) initialisations.
            v0 = v0.reshape(batch_num, mn, 1)

        # The dense/sparse choice does not change across iterations; decide once.
        matvec = sbmm if M.is_sparse else torch.bmm

        v = vlast = v0
        for _ in range(self.max_iter):
            v = matvec(M, v)
            # Normalise each batch element to unit L2 norm over the vector axis.
            # NOTE(review): if M @ v is exactly zero for some batch element this
            # division yields NaN; confirm callers never pass such matrices.
            v = v / torch.norm(v, p=2, dim=1, keepdim=True)
            if torch.norm(v - vlast) < self.stop_thresh:
                break
            vlast = v
        return v.reshape(batch_num, -1)
if __name__ == '__main__':
    # Numerically verify the autograd gradients of the power-iteration layer
    # on a random double-precision batch of three 40x40 matrices.
    from torch.autograd import gradcheck

    inputs = (torch.randn(3, 40, 40, dtype=torch.double, requires_grad=True),)
    layer = SpectralMatching()
    test = gradcheck(layer, inputs, eps=1e-6, atol=1e-4)
    print(test)