# Source code for src.build_graphs

import torch
from torch import Tensor
from scipy.spatial import Delaunay
from scipy.spatial.qhull import QhullError

import itertools
import numpy as np

from typing import Tuple

def build_graphs(P_np: np.ndarray, n: int, n_pad: int=None, edge_pad: int=None, stg: str='fc', sym: bool=True,
                 thre: int=0) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
    r"""
    Build graph matrix :math:`\mathbf G, \mathbf H` from point set :math:`\mathbf P`.
    This function supports only cpu operations in numpy.
    :math:`\mathbf G, \mathbf H` are constructed from adjacency matrix :math:`\mathbf A`:
    :math:`\mathbf A = \mathbf G \cdot \mathbf H^\top`

    :param P_np: :math:`(n\times 2)` point set containing point coordinates
    :param n: number of exact points in the point set; only ``P_np[0:n]`` is used
    :param n_pad: number of rows of G and H (padded node count). Defaults to ``n``
    :param edge_pad: number of columns of G and H (padded edge count). Defaults to ``edge_num``
    :param stg: strategy to build graphs. Options: ``fc``, ``near``, ``tri``
    :param sym: True to emit one column of G/H for every non-zero entry of A,
        False to emit columns only for entries with ``j >= i`` (upper half of A)
    :param thre: The threshold value of 'near' strategy
    :return: :math:`A`, :math:`G`, :math:`H`, edge_num
    :raises AssertionError: if ``stg`` is unknown, if ``n`` or the resulting edge
        count is not positive, or if ``n_pad`` / ``edge_pad`` are smaller than
        ``n`` / ``edge_num``

    The possible options for ``stg``:
    ::

        'fc'(default): construct a fully-connected graph
        'near': construct a fully-connected graph, but edges longer than ``thre`` are removed
        'tri': apply Delaunay triangulation

    ``edge_num`` is always the sum over the whole matrix A, also when ``sym`` is
    False; in that case only the upper-half entries get a column in G and H and
    the remaining columns stay zero.

    An illustration of :math:`\mathbf G, \mathbf H` with their connections to the graph, the adjacency matrix,
    the incident matrix is

    .. image:: ../../images/build_graphs_GH.png
    """

    assert stg in ('fc', 'tri', 'near'), 'No strategy named {} found.'.format(stg)

    points = P_np[0:n, :]
    if stg == 'tri':
        A = delaunay_triangulate(points)
    elif stg == 'near':
        A = fully_connect(points, thre=thre)
    else:
        A = fully_connect(points)
    edge_num = int(np.sum(A, axis=(0, 1)))
    assert n > 0 and edge_num > 0, 'Error in n = {} and edge_num = {}'.format(n, edge_num)

    # Without explicit padding the factor matrices are exactly as large as needed.
    if n_pad is None:
        n_pad = n
    if edge_pad is None:
        edge_pad = edge_num
    assert n_pad >= n
    assert edge_pad >= edge_num

    # float32 so the factors can be fed to the float32 torch code in this module.
    G = np.zeros((n_pad, edge_pad), dtype=np.float32)
    H = np.zeros((n_pad, edge_pad), dtype=np.float32)

    # np.nonzero walks the matrix in row-major order, i.e. the same (i, j)
    # order as a nested ``for i ... for j ...`` loop, so column k of G/H
    # describes the k-th edge met in that scan.
    edge_mask = A == 1
    if not sym:
        edge_mask = np.triu(edge_mask)
    src, dst = np.nonzero(edge_mask)
    columns = np.arange(src.size)
    G[src, columns] = 1
    H[dst, columns] = 1

    return A, G, H, edge_num

def delaunay_triangulate(P: np.ndarray) -> np.ndarray:
    r"""
    Perform delaunay triangulation on point set P.

    Point sets with fewer than three points cannot be triangulated and are
    returned as a fully-connected graph. The same fallback is used when Qhull
    rejects the input (the error is printed, not raised).

    :param P: :math:`(n\times 2)` point set
    :return: :math:`(n\times n)` adjacency matrix :math:`A` with ones between
        every two points that share a simplex and zeros elsewhere
    """
    n = P.shape[0]
    if n < 3:
        return fully_connect(P)

    try:
        d = Delaunay(P)
    except QhullError as err:
        print('Delaunay triangulation error detected. Return fully-connected graph.')
        print('Traceback:')
        print(err)
        return fully_connect(P)

    A = np.zeros((n, n))
    for simplex in d.simplices:
        # Ordered pairs, so both A[i, j] and A[j, i] are set for each edge.
        for pair in itertools.permutations(simplex, 2):
            A[pair] = 1
    return A

def fully_connect(P: np.ndarray, thre=None) -> np.ndarray:
    r"""
    Return the adjacency matrix of a fully-connected graph.

    :param P: :math:`(n\times 2)` point set
    :param thre: edges that are longer than this threshold will be removed.
        ``None`` (default) keeps every edge
    :return: :math:`(n\times n)` adjacency matrix :math:`A` with a zero diagonal
    """
    n = P.shape[0]
    A = np.ones((n, n)) - np.eye(n)
    if thre is not None:
        # Pairwise Euclidean distances via broadcasting: (n, 1, d) - (1, n, d).
        dist = np.linalg.norm(P[:, None] - P[None, :], axis=-1)
        A[dist > thre] = 0
    return A

def make_grids(start, stop, num) -> np.ndarray:
    r"""
    Make grids.

    Every dimension ``i`` is split into ``num[i]`` equal cells between
    ``start[i]`` and ``stop[i]``; the returned points are the cell centres.
    The first dimension varies fastest along the rows of the result.

    This function supports only cpu operations in numpy.

    :param start: start index in all dimensions
    :param stop: stop index in all dimensions
    :param num: number of grids in each dimension
    :return: point set P of shape :math:`(\prod_i num_i \times len(num))`, dtype float32
    :raises AssertionError: if ``start``, ``stop`` and ``num`` differ in length
    """
    assert len(start) == len(stop) == len(num)
    length = int(np.prod(num))
    P = np.zeros((length, len(num)), dtype=np.float32)

    inner = 1  # product of the grid counts of all faster-varying dimensions
    for i, (begin, end, n) in enumerate(zip(start, stop, num)):
        edges = np.linspace(begin, end, n + 1)
        # Shift the n upper cell edges back by half a cell to get the centres.
        centres = edges[1:] - (edges[1] - edges[0]) / 2
        # Hold each centre for one full sweep of the faster dimensions, then
        # repeat the whole pattern for every combination of the slower ones.
        outer = length // (inner * n)
        P[:, i] = np.tile(np.repeat(centres, inner), outer)
        inner *= n
    return P

def reshape_edge_feature(F: Tensor, G: Tensor, H: Tensor, device=None) -> Tensor:
    r"""
    Given point-level features extracted from images, reshape it into edge feature matrix :math:`X`,
    where features are arranged by the order of :math:`G`, :math:`H`.

    .. math::
        \mathbf{X}_{e_{ij}} = concat(\mathbf{F}_i, \mathbf{F}_j)

    where :math:`e_{ij}` means an edge connecting nodes :math:`i, j`

    :param F: :math:`(b\times d \times n)` extracted point-level feature matrix.
        :math:`b`: batch size. :math:`d`: feature dimension. :math:`n`: number of nodes.
    :param G: :math:`(b\times n \times e)` factorized adjacency matrix, where
        :math:`\mathbf A = \mathbf G \cdot \mathbf H^\top`. :math:`e`: number of edges.
    :param H: :math:`(b\times n \times e)` factorized adjacency matrix, where
        :math:`\mathbf A = \mathbf G \cdot \mathbf H^\top`
    :param device: device. If not specified, it will be the same as the input
    :return: edge feature matrix X :math:`(b \times 2d \times e)`, dtype float32
    """
    if device is None:
        device = F.device

    # F @ G selects, per edge, the feature column of the node marked in G;
    # F @ H does the same for the node marked in H.
    feat_g = torch.matmul(F, G)
    feat_h = torch.matmul(F, H)

    # Stack the two d-dimensional halves along the feature axis.
    X = torch.cat((feat_g, feat_h), dim=1)
    return X.to(device=device, dtype=torch.float32)