Source code for src.build_graphs

import torch
from torch import Tensor
from scipy.spatial import Delaunay
from scipy.spatial.qhull import QhullError

import itertools
import numpy as np

from typing import Tuple


def build_graphs(P_np: np.ndarray, n: int, n_pad: int = None, edge_pad: int = None, stg: str = 'fc',
                 sym: bool = True, thre: int = 0) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
    r"""
    Build graph matrix :math:`\mathbf G, \mathbf H` from point set :math:`\mathbf P`.
    This function supports only cpu operations in numpy.
    :math:`\mathbf G, \mathbf H` are constructed from adjacency matrix :math:`\mathbf A`:
    :math:`\mathbf A = \mathbf G \cdot \mathbf H^\top`

    :param P_np: :math:`(n\times 2)` point set containing point coordinates
    :param n: number of exact points in the point set (only the first ``n`` rows of ``P_np`` are used)
    :param n_pad: padded node length. Defaults to ``n``
    :param edge_pad: padded edge length. Defaults to ``edge_num``
    :param stg: strategy to build graphs. Options: ``fc``, ``near``, ``tri``
    :param sym: True to write every non-zero entry of :math:`A` into :math:`G, H`; False to write only
        the entries with column index >= row index (the upper half)
    :param thre: The threshold value of 'near' strategy
    :return: :math:`A`, :math:`G`, :math:`H`, edge_num

    The possible options for ``stg``:
    ::

        'fc'(default): construct a fully-connected graph
        'near': construct a fully-connected graph, but edges longer than ``thre`` are removed
        'tri': apply Delaunay triangulation

    .. note::
        ``edge_num`` is the sum of all entries of :math:`A` and :math:`A` itself is returned unmodified,
        independently of ``sym``. With ``sym=False`` only the upper-half entries are written into
        :math:`G, H`; the remaining columns stay zero.

    .. note::
        With ``stg='near'`` and the default ``thre=0`` every edge of non-zero length is removed, so the
        ``edge_num > 0`` assertion fails unless points coincide. Pass a positive ``thre`` for this strategy.

    An illustration of :math:`\mathbf G, \mathbf H` with their connections to the graph, the adjacency matrix,
    the incident matrix is

    .. image:: ../../images/build_graphs_GH.png
    """
    assert stg in ('fc', 'tri', 'near'), 'No strategy named {} found.'.format(stg)

    if stg == 'tri':
        A = delaunay_triangulate(P_np[0:n, :])
    elif stg == 'near':
        A = fully_connect(P_np[0:n, :], thre=thre)
    else:
        A = fully_connect(P_np[0:n, :])

    edge_num = int(np.sum(A, axis=(0, 1)))
    assert n > 0 and edge_num > 0, 'Error in n = {} and edge_num = {}'.format(n, edge_num)

    if n_pad is None:
        n_pad = n
    if edge_pad is None:
        edge_pad = edge_num
    assert n_pad >= n
    assert edge_pad >= edge_num

    G = np.zeros((n_pad, edge_pad), dtype=np.float32)
    H = np.zeros((n_pad, edge_pad), dtype=np.float32)

    # np.nonzero walks the matrix in row-major order, i.e. the same (i, j)
    # order a nested ``for i ... for j ...`` scan would produce, so edge
    # indices are assigned in the same sequence.
    rows, cols = np.nonzero(A == 1)
    if not sym:
        # Half adjacency: keep only entries on or above the diagonal.
        upper = cols >= rows
        rows, cols = rows[upper], cols[upper]

    edge_ids = np.arange(rows.size)
    G[rows, edge_ids] = 1  # column k of G marks the start node of edge k
    H[cols, edge_ids] = 1  # column k of H marks the end node of edge k

    return A, G, H, edge_num
def delaunay_triangulate(P: np.ndarray) -> np.ndarray:
    r"""
    Perform delaunay triangulation on point set P.

    Point sets with fewer than 3 points are fully connected instead. If Qhull fails to
    triangulate the points, a message is printed and a fully-connected graph is returned.

    :param P: :math:`(n\times 2)` point set
    :return: adjacency matrix :math:`A`
    """
    n = P.shape[0]
    if n < 3:
        return fully_connect(P)

    # Only the triangulation itself can raise QhullError, so keep the try body minimal.
    try:
        triangulation = Delaunay(P)
    except QhullError as err:
        print('Delaunay triangulation error detected. Return fully-connected graph.')
        print('Traceback:')
        print(err)
        return fully_connect(P)

    A = np.zeros((n, n))
    simplices = triangulation.simplices  # one row of vertex indices per simplex
    # Connect every ordered pair of distinct vertex slots for all simplices at once.
    for a, b in itertools.permutations(range(simplices.shape[1]), 2):
        A[simplices[:, a], simplices[:, b]] = 1
    return A
def fully_connect(P: np.ndarray, thre=None) -> np.ndarray:
    r"""
    Return the adjacency matrix of a fully-connected graph.

    :param P: :math:`(n\times 2)` point set
    :param thre: edges that are longer than this threshold will be removed.
        ``None`` (default) keeps every edge
    :return: adjacency matrix :math:`A` (:math:`n\times n`, zero diagonal)
    """
    n = P.shape[0]
    A = np.ones((n, n)) - np.eye(n)

    # With fewer than two points there is no edge to prune (and an empty
    # array cannot be reshaped with an inferred dimension).
    if thre is not None and n > 1:
        # Flatten each point so the norm is taken over all of its coordinates.
        coords = P.reshape(n, -1)
        # Pairwise Euclidean distances, shape (n, n).
        dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
        A[dist > thre] = 0
    return A
def make_grids(start, stop, num) -> np.ndarray:
    r"""
    Make grids. This function supports only cpu operations in numpy.

    Each dimension ``i`` is split into ``num[i]`` equal cells between ``start[i]`` and ``stop[i]``;
    the returned points are the cell centers. The first dimension varies fastest along the rows
    of the result.

    :param start: start index in all dimensions
    :param stop: stop index in all dimensions
    :param num: number of grids in each dimension
    :return: point set P of shape :math:`(\prod_i num_i) \times len(num)`, dtype float32
    """
    assert len(start) == len(stop) == len(num)

    length = int(np.prod(num))
    P = np.zeros((length, len(num)), dtype=np.float32)

    inner = 1  # product of the grid sizes of all dimensions before the current one
    for i, (begin, end, n) in enumerate(zip(start, stop, num)):
        g = np.linspace(begin, end, n + 1)
        g -= (g[1] - g[0]) / 2  # shift cell edges back by half a cell ...
        g = g[1:]               # ... and drop the first to keep the n cell centers

        # Integer division: np.repeat / np.tile reject float repeat counts.
        outer = length // (inner * n)
        # Each value is held for `inner` rows and the whole pattern is cycled `outer` times.
        P[:, i] = np.tile(np.repeat(g, inner), outer)
        inner *= n
    return P
def reshape_edge_feature(F: Tensor, G: Tensor, H: Tensor, device=None) -> Tensor:
    r"""
    Given point-level features extracted from images, reshape it into edge feature matrix :math:`X`,
    where features are arranged by the order of :math:`G`, :math:`H`.

    .. math::
        \mathbf{X}_{e_{ij}} = concat(\mathbf{F}_i, \mathbf{F}_j)

    where :math:`e_{ij}` means an edge connecting nodes :math:`i, j`

    :param F: :math:`(b\times d \times n)` extracted point-level feature matrix.
        :math:`b`: batch size. :math:`d`: feature dimension. :math:`n`: number of nodes.
    :param G: :math:`(b\times n \times e)` factorized adjacency matrix, where
        :math:`\mathbf A = \mathbf G \cdot \mathbf H^\top`. :math:`e`: number of edges.
    :param H: :math:`(b\times n \times e)` factorized adjacency matrix, where
        :math:`\mathbf A = \mathbf G \cdot \mathbf H^\top`
    :param device: device. If not specified, it will be the same as the input
    :return: edge feature matrix X :math:`(b \times 2d \times e)`, always ``torch.float32``
    """
    target_device = F.device if device is None else device

    num_batches, dim = F.shape[0], F.shape[1]
    _, num_edges = G.shape[1:3]

    X = torch.zeros(num_batches, 2 * dim, num_edges, dtype=torch.float32, device=target_device)
    # First d channels: features gathered through G; last d channels: features gathered through H.
    for offset, factor in ((0, G), (dim, H)):
        X[:, offset:offset + dim, :] = torch.matmul(F, factor)
    return X