Source code for src.utils.sparse

import sys
import os
import torch
from torch.autograd import Function
import numpy as np
import scipy.sparse as ssp

from src.sparse_torch import CSRMatrix3d, CSCMatrix3d

if 'SPHINX' not in os.environ:
    from torch.utils.cpp_extension import load
    bilinear_diag = load(name='bilinear_diag', sources=['src/extension/bilinear_diag/bilinear_diag.cpp',
                                                        'src/extension/bilinear_diag/bilinear_diag_cuda.cu'],
                         extra_include_paths=[
                             '/usr/include/python{}.{}/'.format(sys.version_info.major, sys.version_info.minor)]
                         )


def to_sparse(x, dense_dim=1):
    """
    Convert dense tensor x to sparse format.
    """
    x_typename = torch.typename(x).split('.')[-1]
    sparse_tensortype = getattr(torch.sparse, x_typename)

    indices = torch.nonzero(x, as_tuple=False)[:, :len(x.shape) - dense_dim + 1]
    if indices.shape[0] == 0:  # all elements are zero (torch.nonzero returns an (nnz, ndim) tensor)
        return sparse_tensortype(*x.shape)
    indices = indices.t()
    values = x[tuple(indices[i] for i in range(indices.shape[0]))]
    return sparse_tensortype(indices, values, x.shape)

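
# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# Converts a small dense matrix with the default dense_dim=1 and recovers it; the helper name
# _example_to_sparse is hypothetical and exists only for this example.
def _example_to_sparse():
    x = torch.tensor([[0., 2., 0.],
                      [1., 0., 0.]])
    sp_x = to_sparse(x)                      # legacy torch.sparse.FloatTensor with 2 stored entries
    assert torch.equal(sp_x.to_dense(), x)   # round-trip reproduces the dense input
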
def sbmm(t1, t2):
    """
    Perform bmm (Batch Matrix Matrix) for sparse x dense -> dense.
    """
    return SparseDenseDenseBMM.apply(t1, t2)

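
# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# Batched sparse x dense multiplication via the autograd-aware wrapper, following the
# (b, x, s) x (b, s, y) -> (b, x, y) convention used throughout this module.
def _example_sbmm():
    s_t1 = dense_to_sparse(torch.randn(2, 3, 4))     # sparse (b, x, s)
    d_t2 = torch.randn(2, 4, 5)                      # dense  (b, s, y)
    out = sbmm(s_t1, d_t2)                           # dense  (2, 3, 5)
    assert torch.allclose(out, torch.bmm(s_t1.to_dense(), d_t2))
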
def sbmm_diag(t1, t2):
    """
    Perform bmm and diagonal for sparse x dense -> dense. The diagonalized result is returned as a vector tensor.
    This is a wrapper function and does not support gradient.
    """
    assert t1.is_sparse != t2.is_sparse, 't1, t2 must be one sparse and one dense!'
    return sdd_bmm_diag_torch(t1, t2)

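
# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# Diagonal of a batched sparse x dense product, checked against the naive dense computation.
def _example_sbmm_diag():
    s_t1 = dense_to_sparse(torch.randn(2, 3, 4))     # sparse (b, x, s)
    d_t2 = torch.randn(2, 4, 3)                      # dense  (b, s, x)
    diag = sbmm_diag(s_t1, d_t2)                     # dense  (2, 3)
    ref = torch.diagonal(torch.bmm(s_t1.to_dense(), d_t2), dim1=1, dim2=2)
    assert torch.allclose(diag, ref)
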
def sdsbmm(t1, t2):
    """
    Perform bmm for sparse x dense -> sparse.
    This is a wrapper function and does not support gradient.
    """
    assert (type(t1) == list) != (type(t2) == list) or t1.is_sparse != t2.is_sparse, \
        't1, t2 must be one sparse and one dense!'
    if type(t1) == list or t1.is_sparse:
        result = sds_bmm_torch(t1, t2)
    else:
        result = sds_bmm_torch(t2.transpose(1, 2), t1.transpose(1, 2)).transpose(1, 2)
    return result

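
# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# The sparse operand is passed as a list of 2d sparse matrices (one per batch); the result is a list
# of sparse matrices computed with torch.smm on CPU.
def _example_sdsbmm():
    s_list = [dense_to_sparse(torch.rand(3, 4)) for _ in range(2)]   # batch of 2 sparse (x, s) matrices
    d_t2 = torch.rand(2, 4, 5)                                       # dense (b, s, y)
    out = sdsbmm(s_list, d_t2)                                       # list of 2 sparse (3, 5) matrices
    assert torch.allclose(out[0].to_dense(), torch.mm(s_list[0].to_dense(), d_t2[0]))
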
def sssbmm_diag(m1, m2):
    """
    Perform bmm and diagonal for sparse x sparse -> sparse.
    This is a wrapper function and does not support gradient.
    """
    if (type(m1) == list and type(m1[0]) == torch.Tensor) or type(m1) == torch.Tensor:
        m1 = torch2ssp(m1)
    if (type(m2) == list and type(m2[0]) == torch.Tensor) or type(m2) == torch.Tensor:
        m2 = torch2ssp(m2)
    return sss_bmm_diag_spp(m1, m2)

''' Torch API Functions '''
class SparseDenseDenseBMM(Function):
    """
    bmm (Batch Matrix Matrix) for sparse x dense -> dense.
    With s_t1.shape = (b, x, s) and d_t2.shape = (b, s, y), the output shape is (b, x, y).
    This is a workaround utilizing torch.mm for sparse x dense -> dense.
    Both forward and backward passes are implemented.
    """

    @staticmethod
    def forward(ctx, t1, t2):
        """
        :param t1: tensor 1
        :param t2: tensor 2
        :return: bmm result in dense
        """
        ctx.save_for_backward(t1, t2)
        assert t1.is_sparse != t2.is_sparse, 't1, t2 must be one sparse and one dense!'
        if t1.is_sparse:
            result = sdd_bmm_torch(t1, t2)
        else:
            result = sdd_bmm_torch(t2.transpose(1, 2), t1.transpose(1, 2)).transpose(1, 2)
        return result

    @staticmethod
    def backward(ctx, dm):
        s_t1, d_t2 = ctx.saved_tensors
        dt1 = dt2 = None
        if ctx.needs_input_grad[0]:
            dt1 = torch.bmm(dm, d_t2.transpose(1, 2))
            dt1 = dense_to_sparse(dt1)
        if ctx.needs_input_grad[1]:
            dt2 = sdd_bmm_torch(s_t1.transpose(1, 2), dm)
        return dt1, dt2

def sdd_bmm_torch(s_t1, d_t2):
    """
    bmm (Batch Matrix Matrix) for sparse x dense -> dense. This function itself doesn't support gradient.
    With s_t1.shape = (b, x, s) and d_t2.shape = (b, s, y), the output shape is (b, x, y).
    This is a workaround utilizing torch.mm for sparse x dense -> dense.

    :param s_t1: sparse tensor 1
    :param d_t2: dense tensor 2
    :return: bmm result in dense
    """
    device = s_t1.device
    batch_num = s_t1.shape[0]
    x = s_t1.shape[1]
    y = d_t2.shape[2]

    assert s_t1.shape[0] == d_t2.shape[0], 'Batch size mismatch.'
    assert s_t1.shape[2] == d_t2.shape[1], 'Matrix shape mismatch.'

    outp = torch.empty(batch_num, x, y, dtype=s_t1.dtype, device=device)
    for b in range(batch_num):
        _s_t1 = get_batches(s_t1, b)
        torch.mm(_s_t1, d_t2[b, :, :], out=outp[b, :, :])
    return outp

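
# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# The non-autograd workhorse behind SparseDenseDenseBMM: each batch is multiplied with torch.mm
# on a 2d sparse slice extracted by get_batches.
def _example_sdd_bmm_torch():
    s_t1 = dense_to_sparse(torch.rand(3, 5, 6))      # sparse (b, x, s)
    d_t2 = torch.rand(3, 6, 2)                       # dense  (b, s, y)
    out = sdd_bmm_torch(s_t1, d_t2)                  # dense  (3, 5, 2)
    assert torch.allclose(out, torch.bmm(s_t1.to_dense(), d_t2))
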
def sdd_bmm_diag_torch(t1, t2):
    """
    Perform bmm and diagonal for sparse x dense -> dense. The diagonalized result is returned as a vector tensor.
    With t1.shape = (b, x, s) and t2.shape = (b, s, x), the output shape is (b, x).
    This method avoids a temporary (b, x, x) tensor for memory efficiency.

    :param t1: tensor 1
    :param t2: tensor 2
    :return: bmm_diag result in dense
    """
    assert t1.shape[0] == t2.shape[0], 'Batch size mismatch.'
    assert t1.shape[2] == t2.shape[1] and t1.shape[1] == t2.shape[2], 'Matrix shape mismatch.'
    if t1.is_sparse:
        d_t1 = t1.transpose(1, 2).to_dense()
        outp = torch.sum(d_t1.mul_(t2), dim=1)
    else:
        d_t2 = t2.transpose(1, 2).to_dense()
        outp = torch.sum(d_t2.mul_(t1), dim=2)
    return outp

def sds_bmm_torch(s_t1, d_t2):
    """
    bmm (Batch Matrix Matrix) for sparse x dense -> sparse. This function doesn't support gradient,
    and sparse tensors cannot receive gradients due to the limitation of the torch implementation.
    With s_t1.shape = (b, x, s) and d_t2.shape = (b, s, y), the output shape is (b, x, y).
    This is a workaround utilizing torch.smm for sparse x dense -> sparse.

    :param s_t1: sparse tensor 1 (in list, representing batches)
    :param d_t2: dense tensor 2
    :return: bmm result in sparse (in list, representing batches)
    """
    device = d_t2.device
    assert type(s_t1) == list
    batch_num = len(s_t1)

    assert batch_num == d_t2.shape[0], 'Batch size mismatch.'

    outp = []
    for b in range(batch_num):
        # force cpu
        _s_t1 = s_t1[b].cpu()
        _d_t2 = d_t2[b].cpu()

        assert _s_t1.shape[1] == _d_t2.shape[0], 'Matrix shape mismatch.'

        _outp = torch.smm(_s_t1, _d_t2)  # CUDA version of smm is not implemented
        outp.append(_outp)

    return outp

def bilinear_diag_torch(s_t1: CSRMatrix3d, d_t2: torch.Tensor, s_t3: CSCMatrix3d, device=None):
    """
    Bilinear and diagonal in sequence, for diagonal(sparse x dense x sparse) -> dense vector.
    With s_t1.shape = (b, x, y), d_t2.shape = (b, y, y) and s_t3.shape = (b, y, x), the output shape is (b, x).
    In this function, the two sparse tensors (s_t1 and s_t3) are represented in CSR and CSC format to guarantee
    efficient computation.
    The main operation is implemented in a custom C++ extension, and will be ~1000x faster if CUDA is available.

    :param s_t1: CSR matrix 1
    :param d_t2: dense tensor 2
    :param s_t3: CSC matrix 3
    :param device: device. If not specified, it will be the same as the input.
    :return: returned dense vector
    """
    if device is None:
        device = d_t2.device
    # dtype = d_t2.dtype

    batch_num = s_t1.shape[0]
    xlen = s_t1.shape[1]

    assert s_t1.shape[0] == d_t2.shape[0] == s_t3.shape[0], 'Batch size mismatch.'
    assert s_t1.shape[1] == s_t3.shape[2], 'Sparse matrix 1 & 3 shape mismatch.'
    assert s_t1.shape[2] == d_t2.shape[1] == d_t2.shape[2] == s_t3.shape[1], 'Matrix size mismatch.'

    '''
    s_t1_input = [[], [], []]
    s_t3_input = [[], [], []]
    t1_indptr_offset = 0
    t3_indptr_offset = 0
    for b in range(batch_num):
        _s_t1 = s_t1[b].tocsc()
        s_t1_input[0].append(_s_t1.indices)
        s_t1_input[1].append(_s_t1.indptr[:-1] + t1_indptr_offset)
        s_t1_input[2].append(_s_t1.data)
        t1_indptr_offset += _s_t1.indptr[-1]

        _s_t3 = s_t3[b].tocsc()
        s_t3_input[0].append(_s_t3.indices)
        s_t3_input[1].append(_s_t3.indptr[:-1] + t3_indptr_offset)
        s_t3_input[2].append(_s_t3.data)
        t3_indptr_offset += _s_t3.indptr[-1]

    s_t1_input[1].append(np.array([t1_indptr_offset]))
    s_t3_input[1].append(np.array([t3_indptr_offset]))

    for input in (s_t1_input, s_t3_input):
        for idx in range(len(input)):
            if idx == 2:  # data
                _dtype = dtype
            else:  # indices & indptr
                _dtype = torch.int64
            input[idx] = torch.tensor(np.concatenate(input[idx]), dtype=_dtype, device=device)
    '''

    outp = bilinear_diag.bilinear_diag(*s_t1.as_list(), d_t2, *s_t3.as_list(), batch_num, xlen)

    return outp.to(device)

def dense_to_sparse(d_t):
    """
    Convert a dense tensor to a sparse one.

    :param d_t: dense tensor
    :return: sparse tensor
    """
    dtype = d_t.dtype
    device = d_t.device
    req_grad = d_t.requires_grad

    indices = torch.nonzero(d_t)
    if indices.shape[0] == 0:  # all elements are zero (torch.nonzero returns an (nnz, ndim) tensor)
        return torch.sparse_coo_tensor(torch.empty((d_t.dim(), 0), dtype=torch.long, device=device),
                                       torch.empty(0, dtype=dtype, device=device),
                                       d_t.shape, dtype=dtype, device=device, requires_grad=req_grad)
    indices = indices.t()
    values = d_t[tuple(indices[i] for i in range(indices.shape[0]))]
    return torch.sparse_coo_tensor(indices, values, d_t.size(), dtype=dtype, device=device, requires_grad=req_grad)

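
# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# Round-trips a dense tensor through the sparse COO representation.
def _example_dense_to_sparse():
    d = torch.tensor([[0., 3., 0.],
                      [0., 0., 5.]])
    s = dense_to_sparse(d)
    assert s.is_sparse and s._nnz() == 2
    assert torch.equal(s.to_dense(), d)
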
def get_batches(s_t, b=None, device=None):
    """
    Get batches from a 3d sparse tensor.

    :param s_t: sparse tensor
    :param b: if None, return all batches in a list; else, return a specific batch
    :param device: device. If None, it will be the same as input
    :return: sparse tensor or list of sparse tensors
    """
    if device is None:
        device = s_t.device
    coo = s_t._indices()
    data = s_t._values()
    if b is not None:
        idx = (coo[0, :] == b).nonzero()
        _coo = coo[1:3, idx].view(2, -1)
        _data = data[idx].view(-1)
        outp = torch.sparse_coo_tensor(_coo, _data, s_t.shape[1:3], dtype=_data.dtype, device=device)
    else:
        batch_num = s_t.shape[0]
        outp = []
        for b in range(batch_num):
            idx = (coo[0, :] == b).nonzero()
            _coo = coo[1:3, idx].view(2, -1)
            _data = data[idx].view(-1)
            outp.append(torch.sparse_coo_tensor(_coo, _data, s_t.shape[1:3], dtype=_data.dtype, device=device))
    return outp

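
# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# Splits a 3d sparse tensor into its per-batch 2d sparse slices, or fetches a single batch.
def _example_get_batches():
    s_t = dense_to_sparse(torch.rand(4, 3, 3))
    one = get_batches(s_t, b=1)                      # 2d sparse tensor for batch index 1
    all_b = get_batches(s_t)                         # list of 4 2d sparse tensors
    assert len(all_b) == 4
    assert torch.allclose(one.to_dense(), s_t.to_dense()[1])
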
def slicing_torch(s_t, slice, preserve_dim=False):
    """
    A slicing function for torch sparse tensors.

    :param s_t: input sparse tensor
    :param slice: tensor containing indices, -1 stands for all.
                  For example, (1, -1) returns the second row of a 2d tensor.
    :param preserve_dim: If True, the dimension of the original tensor will be preserved,
                         i.e. 1 will be padded for those removed dimensions.
    :return: sliced sparse tensor
    """
    device = s_t.device
    dim = slice.shape[0]
    assert len(s_t.shape) == dim
    coo = s_t._indices()
    data = s_t._values()
    idx_flag = torch.ones(coo.shape[1], dtype=torch.uint8, device=device)
    for i in range(dim):
        s = slice[i]
        if s == -1:
            continue
        _idx_flag = (coo[i, :] == s).view(-1)
        idx_flag.mul_(_idx_flag)
    idx = idx_flag.nonzero().view(-1)
    if not preserve_dim:
        dim_flag = (slice == -1).nonzero().view(-1)
        if dim_flag.numel() == 0:
            coo = torch.tensor([[0]], dtype=coo.dtype, device=device)
            shape = torch.Size([1])
        else:
            coo = coo[:, idx]
            coo = coo[dim_flag, :]
            shape = torch.Size(torch.tensor(s_t.shape)[dim_flag])
    else:
        coo = coo[:, idx]
        coo.mul_((slice == -1).type(coo.dtype).view(-1, 1))
        _dtype = torch.int32
        shape = torch.Size(torch.tensor(s_t.shape, dtype=_dtype, device=device) * (slice == -1).type(_dtype)
                           + torch.ones(len(s_t.shape), dtype=_dtype, device=device) * (slice != -1).type(_dtype))
    data = data[idx]
    return torch.sparse_coo_tensor(coo, data, shape, dtype=s_t.dtype, device=s_t.device)

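
# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# Takes the second row of a 2d sparse tensor; -1 in the slice tensor means "keep this dimension".
# (The __main__ block at the bottom of this module shows the preserve_dim=True variant.)
def _example_slicing_torch():
    d = torch.tensor([[1., 0., 2.],
                      [0., 3., 0.]])
    s = dense_to_sparse(d)
    row1 = slicing_torch(s, torch.tensor((1, -1)))   # sparse vector of shape (3,)
    assert torch.equal(row1.to_dense(), d[1])
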
''' scipy.sparse API Functions '''
def sss_bmm_diag_spp(s_m1, s_m2):
    """
    bmm (Batch Matrix Matrix) for sparse x sparse -> sparse. The diagonalized result is returned as a vector.
    With s_m1.shape = (b, x, s) and s_m2.shape = (b, s, x), the output shape is (b, x).
    This function doesn't support gradient.

    :param s_m1: sparse matrix 1
    :param s_m2: sparse matrix 2
    :return: result in sparse vector
    """
    if type(s_m1) != list:
        s_m1 = [s_m1]
    if type(s_m2) != list:
        s_m2 = [s_m2]
    assert len(s_m1) == len(s_m2), 'Batch size mismatch.'

    outp = []
    for _m1, _m2 in zip(s_m1, s_m2):
        assert _m1.shape[1] == _m2.shape[0] and _m1.shape[0] == _m2.shape[1], 'Matrix shape mismatch.'
        # scipy's .diagonal() returns a dense ndarray, so wrap it back into a sparse coo_matrix
        outp.append(ssp.coo_matrix(_m1.dot(_m2).diagonal()))
    return outp

''' Conversion Functions '''
def ssp2torch(M, batch='dim', dtype=torch.float32, device=None):
    """
    Convert a scipy.sparse matrix to a torch sparse matrix. Since scipy.sparse has a dimension limit of 2,
    a list of matrices is supported for batches.

    :param M: input scipy.sparse matrix
    :param batch: the type that represents batches in the output.
                  If batch='list', tensors are 2d and stored in a list.
                  If batch='dim', tensors are 3d and the first dimension represents batch size.
    :param dtype: output data type
    :param device: device
    :return: output torch sparse matrix
    """
    assert batch in ('list', 'dim')
    if type(M) != list:
        M = [M]
    batch_num = len(M)

    if batch == 'list':
        outp = []
        for i in range(batch_num):
            _M = M[i]
            _M = _M.tocoo()
            coo = np.array([_M.row, _M.col])
            data = _M.data
            outp.append(torch.sparse_coo_tensor(coo, data, _M.shape, dtype=dtype, device=device))
    else:
        batch, row, col, data = np.array([]), np.array([]), np.array([]), np.array([])
        for i in range(batch_num):
            _M = M[i]
            _M = _M.tocoo()
            batch = np.append(batch, np.ones(_M.nnz) * i)
            row = np.append(row, _M.row)
            col = np.append(col, _M.col)
            data = np.append(data, _M.data)
        coo = np.array([batch, row, col])
        outp = torch.sparse_coo_tensor(coo, data, torch.Size([batch_num] + list(_M.shape)),
                                       dtype=dtype, device=device)
    return outp

def torch2ssp(M):
    """
    Convert a torch sparse matrix to a scipy.sparse matrix. Since scipy.sparse has a dimension limit of 2,
    batches are represented as a list in the output.

    :param M: input torch sparse matrix
    :return: output scipy.sparse matrix
    """
    if type(M) == list:
        batch_num = len(M)
        outp = []
        for b in range(batch_num):
            _M = M[b]
            _coo = _M._indices()
            _data = _M._values()
            outp.append(ssp.coo_matrix((_data, _coo), _M.shape))
    else:
        coo = M._indices()
        data = M._values()
        batch_num = M.shape[0]
        if len(M.shape) == 2:
            outp = ssp.coo_matrix((data, coo), M.shape)
        else:
            assert len(M.shape) == 3
            outp = []
            for b in range(batch_num):
                idx = (coo[0, :] == b).nonzero()
                _coo = coo[1:3, idx].view(2, -1)
                _data = data[idx].view(-1)
                outp.append(ssp.coo_matrix((_data, _coo), M.shape[1:3]))
    return outp

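
# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# Round-trips a scipy.sparse matrix through the torch sparse representation and back.
def _example_ssp_torch_roundtrip():
    m = ssp.coo_matrix(np.array([[0., 1.],
                                 [2., 0.]]))
    t = ssp2torch(m, batch='list')[0]                # 2d torch sparse tensor
    m2 = torch2ssp(t)                                # back to a scipy.sparse coo_matrix
    assert np.allclose(m.toarray(), m2.toarray())
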
def recover_ssp(t_dict):
    """
    Recover a list of scipy.sparse coo_matrix from a dictionary containing row, col and data tensors.

    :param t_dict: dictionary containing keys 'row', 'col', 'data', each corresponding to a b x n tensor,
                   and 'shape', containing the M x N shape of each matrix
    :return: list of scipy.sparse matrices. List indices represent batches.
    """
    batch_size = t_dict['row'].shape[0]
    np_dict = {key: t_dict[key].numpy() for key in t_dict}
    ss = []
    max_shape = np.zeros((2,), dtype=int)  # np.int is removed in recent numpy; use the builtin int
    for b in range(batch_size):
        shape = np_dict['shape'][b].astype(int)
        max_shape[0] = max(shape[0], max_shape[0])
        max_shape[1] = max(shape[1], max_shape[1])
    for b in range(batch_size):
        data = np_dict['data'][b]
        row = np_dict['row'][b]
        col = np_dict['col'][b]
        _ss = ssp.coo_matrix((data, (row, col)), shape=max_shape)
        ss.append(_ss)
    return ss

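
# --- Illustrative usage sketch (editor's addition, not part of the original module) ---
# The dictionary layout expected by recover_ssp: one batch with two non-zero entries in a 2x2 matrix.
def _example_recover_ssp():
    t_dict = {
        'row':   torch.tensor([[0, 1]]),
        'col':   torch.tensor([[1, 0]]),
        'data':  torch.tensor([[4., 5.]]),
        'shape': torch.tensor([[2, 2]]),
    }
    mats = recover_ssp(t_dict)                       # list with a single 2x2 coo_matrix
    assert mats[0].shape == (2, 2) and mats[0].nnz == 2
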
if __name__ == '__main__':
    t = torch.tensor([[[ 1,  2,  3,  4],
                       [11, 22, 33, 44]]])
    t = dense_to_sparse(t)
    s = slicing_torch(t, torch.tensor((0, 0, 1)), preserve_dim=True)
    print(s.to_dense())

    from torch.autograd import gradcheck
    input = (dense_to_sparse(torch.randn(1, 20, 30, dtype=torch.double, requires_grad=True)),
             torch.randn(1, 30, 40, dtype=torch.double, requires_grad=True))
    test = gradcheck(sbmm, input, eps=1e-6, atol=1e-4)
    print(test)