Source code for src.sparse_torch.csx_matrix

import sys
import os
from typing import Union

import torch
import numpy as np
import scipy.sparse as ssp

if 'SPHINX' not in os.environ:
    from torch.utils.cpp_extension import load

    sparse_dot = load(name='sparse_dot',
                      sources=['src/extension/sparse_dot/sparse_dot.cpp',
                               'src/extension/sparse_dot/csr_dot_csc_cuda.cu',
                               'src/extension/sparse_dot/csr_dot_diag_cuda.cu'],
                      extra_include_paths=[
                          '/usr/include/python{}.{}/'.format(sys.version_info.major, sys.version_info.minor)]
                      )
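
# torch.utils.cpp_extension.load JIT-compiles the listed C++/CUDA sources on
# first import and caches the build; the 'SPHINX' environment guard above
# skips this (slow) compilation step while the documentation is being built.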


class CSXMatrix3d:
    def __init__(self, inp, shape, device=None):
        # `self.sptype` is a valid default below because this inner def is
        # executed inside __init__, where `self` is already bound.
        def from_ssp(inp_s: list, shape, device=None, sptype=self.sptype):
            """
            Load data from a list of scipy.sparse matrices.

            :param inp_s: list of input scipy.sparse matrices
            :param shape: output matrix shape
            :param device: device. If not specified, it will be the same as the input.
            :param sptype: sparse matrix type. Should be 'csr' or 'csc'.
            """
            assert len(shape) == 3, 'Only 3-dimensional tensors (b x h x w) are supported'
            batch_num = shape[0]
            indices = []
            indptr = []
            data = []
            indptr_offset = 0
            for b in range(batch_num):
                if sptype == 'csc':
                    inp_s[b].eliminate_zeros()
                    sp = inp_s[b].tocsc().astype(dtype=inp_s[b].dtype)
                elif sptype == 'csr':
                    inp_s[b].eliminate_zeros()
                    sp = inp_s[b].tocsr().astype(dtype=inp_s[b].dtype)
                else:
                    raise ValueError('Sparse type not understood: {}'.format(sptype))
                indices.append(sp.indices)
                # Drop the trailing entry of each per-batch indptr; the final
                # cumulative entry is appended once after the loop.
                indptr.append(sp.indptr[:-1] + indptr_offset)
                data.append(sp.data)
                indptr_offset += sp.indptr[-1]
            indptr.append(np.array([indptr_offset]))
            return from_tensors(*[np.concatenate(x) for x in (indices, indptr, data)],
                                shape=shape, device=device)

        def from_tensors(ind: Union[torch.Tensor, np.ndarray],
                         indp: Union[torch.Tensor, np.ndarray],
                         data: Union[torch.Tensor, np.ndarray],
                         shape, device=None):
            """
            Load data from raw input tensors/arrays.

            :param ind: indices array/tensor
            :param indp: indptr array/tensor
            :param data: data array/tensor
            :param shape: output matrix shape
            :param device: device. Optional.
            :return: indices (Tensor), indptr (Tensor), data (Tensor), shape (tuple)
            """
            if isinstance(ind, torch.Tensor) and device is None:
                device = ind.device
            if isinstance(ind, torch.Tensor):
                indices_t = ind.to(torch.int64).to(device)
            else:
                indices_t = torch.tensor(ind, dtype=torch.int64, device=device)
            if isinstance(indp, torch.Tensor):
                indptr_t = indp.to(torch.int64).to(device)
            else:
                indptr_t = torch.tensor(indp, dtype=torch.int64, device=device)
            if isinstance(data, torch.Tensor):
                data_t = data.to(device)
            else:
                data_t = torch.tensor(data, device=device)
            return indices_t, indptr_t, data_t, tuple(shape)

        if type(inp) == list and isinstance(inp[0], ssp.spmatrix):
            self.indices, self.indptr, self.data, self.shape = from_ssp(inp, shape, device)
        elif type(inp) == list:
            self.indices, self.indptr, self.data, self.shape = from_tensors(*inp, shape, device)
        else:
            raise ValueError('Data type {} not understood.'.format(type(inp)))

    def __getitem__(self, item):
        """
        Get an item through indexing or slicing. Slicing is only supported on
        the batch dimension.

        :param item: index or slice
        :return: new sparse matrix
        """
        if isinstance(item, int):
            indices, indptr, data = self.get_batch(item)
            return self.__class__([indices, indptr, data], shape=[1] + list(self.shape[1:3]))
        elif isinstance(item, slice):
            indices = []
            indptr = []
            data = []
            indptr_offset = 0
            # slice.indices() resolves None start/stop/step against the batch size
            batch_iter = range(*item.indices(self.shape[0]))
            for b in batch_iter:
                _indices, _indptr, _data = self.get_batch(b)
                indices.append(_indices)
                indptr.append(_indptr[:-1] + indptr_offset)
                data.append(_data)
                indptr_offset = indptr_offset + _indptr[-1]
            assert isinstance(indptr_offset, torch.Tensor)
            indptr.append(indptr_offset.view(1))
            indices = torch.cat(indices)
            indptr = torch.cat(indptr)
            data = torch.cat(data)
            return self.__class__([indices, indptr, data],
                                  shape=[len(batch_iter)] + list(self.shape[1:3]))
        else:
            raise ValueError('Index type {} not supported.'.format(type(item)))

    def __len__(self):
        return self.shape[0]

    @property
    def device(self):
        return self.indices.device

    @property
    def sptype(self):
        raise NotImplementedError
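
    # Storage layout note (illustrative; not from the original source): all
    # batches live back to back in the three flat tensors. For two 2x3 CSR
    # matrices with 2 and 3 nonzeros, the per-batch indptrs [0, 1, 2] and
    # [0, 2, 3] are stored as one cumulative indptr [0, 1, 2, 4, 5], so batch b
    # occupies indptr[b*h : (b+1)*h + 1].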

    def transpose(self, keep_type=False):
        """Transpose the matrix. Implemented by the CSR/CSC subclasses."""
        raise NotImplementedError

    def to(self, tgt):
        """
        Compatible with torch.Tensor.to().

        :param tgt: target, can be a torch.device or a torch.dtype
        :return: a new instance
        """
        if isinstance(tgt, torch.device):
            return self.__class__([x.to(tgt) for x in [self.indices, self.indptr, self.data]],
                                  self.shape)
        elif isinstance(tgt, torch.dtype):
            return self.__class__([self.indices, self.indptr, self.data.to(tgt)], self.shape)
        else:
            raise ValueError('Data type not understood.')

    def cuda(self):
        """
        Compatible with torch.Tensor.cuda().

        :return: a new instance on CUDA
        """
        return self.__class__([x.cuda() for x in [self.indices, self.indptr, self.data]],
                              self.shape)

    def cpu(self):
        """
        Compatible with torch.Tensor.cpu().

        :return: a new instance on CPU
        """
        return self.__class__([x.cpu() for x in [self.indices, self.indptr, self.data]],
                              self.shape)

    def numpy(self):
        """
        Return a dense numpy array.

        :return: dense numpy array
        """
        ret = [x.toarray() for x in self.as_ssp()]
        ret = np.stack(ret, axis=0)
        return ret

    def as_list(self, mask=None):
        """
        Return [indices, indptr, data] as a list.

        :param mask: Optional. An iterable of 3 booleans; each True keeps the
                     corresponding attribute and each False masks it out.
        :return: [indices, indptr, data], filtered by mask
        """
        attrs = [self.indices, self.indptr, self.data]
        if mask is not None:
            ret = [a for m, a in zip(mask, attrs) if m]
        else:
            ret = attrs
        return ret
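
    # For example, as_list(mask=(1, 0, 1)) returns [indices, data], dropping indptr.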

    def as_ssp(self):
        """
        Return scipy.sparse matrices.

        :return: list of scipy.sparse matrices, one per batch
        """
        ret = []
        for b in range(self.shape[0]):
            indice, indptr, data = self.get_batch(b)
            construct_func = ssp.csr_matrix if self.sptype == 'csr' else ssp.csc_matrix
            ret.append(
                construct_func(
                    (data.cpu().numpy(), indice.cpu().numpy(), indptr.cpu().numpy()),
                    shape=self.shape[1:3]
                )
            )
        return ret

    def as_sparse_torch(self):
        """Return the matrix as a torch sparse COO tensor."""
        coo = torch.zeros(3, self.data.shape[0], dtype=torch.long, device=self.device)
        for b in range(self.shape[0]):
            if self.sptype == 'csr':
                start_ptr = b * self.shape[1]
                end_ptr = (b + 1) * self.shape[1] + 1
                compressed_len = self.shape[1]
                compressed_idx = 1
            elif self.sptype == 'csc':
                start_ptr = b * self.shape[2]
                end_ptr = (b + 1) * self.shape[2] + 1
                compressed_len = self.shape[2]
                compressed_idx = 2
            else:
                raise ValueError('Data type not understood.')
            indptr = self.indptr[start_ptr: end_ptr]
            coo[0, indptr[0]: indptr[-1]] = b
            # Expand the compressed dimension: row (csr) or column (csc) i owns
            # the nonzero slots indptr[i]:indptr[i + 1].
            for i in range(compressed_len):
                coo[compressed_idx, indptr[i]: indptr[i + 1]] = i
        if self.sptype == 'csr':
            coo[2, :] = self.indices
        else:
            coo[1, :] = self.indices
        # torch.sparse_coo_tensor supersedes the deprecated torch.sparse.FloatTensor
        return torch.sparse_coo_tensor(coo, self.data, self.shape)
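
    # Expansion sketch (illustrative): for a CSR batch with indptr [0, 2, 3],
    # row 0 owns nonzero slots 0:2 and row 1 owns slot 2:3, so the COO row
    # coordinates become [0, 0, 1]; the stored `indices` already hold the
    # column coordinates.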

    def get_batch(self, item):
        """
        Get a certain batch as a tuple (indices, indptr, data).

        :param item: batch index
        :return: (indices, indptr, data)
        """
        if type(item) != int:
            raise IndexError('Only int indices are currently supported.')
        if self.sptype == 'csr':
            start_idx = item * self.shape[1]
            end_idx = (item + 1) * self.shape[1] + 1
        elif self.sptype == 'csc':
            start_idx = item * self.shape[2]
            end_idx = (item + 1) * self.shape[2] + 1
        else:
            raise ValueError('Data type not understood.')
        indptr = self.indptr[start_idx: end_idx].clone()
        indice = self.indices[indptr[0]: indptr[-1]].clone()
        data = self.data[indptr[0]: indptr[-1]].clone()
        # Rebase so the returned batch's indptr starts at 0
        indptr = indptr - indptr[0]
        return indice, indptr, data

    def shape_eq(self, other):
        """Return True if self and other have the same (b, h, w) shape."""
        return all(s == o for s, o in zip(self.shape, other.shape))
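

# Illustrative usage sketch (not part of the original module; the helper name
# `_demo_batch_slicing` and the random inputs are hypothetical). It shows
# construction from a list of scipy.sparse matrices and batch-dimension
# slicing. CSRMatrix3d is defined further below, which is fine at call time.
def _demo_batch_slicing():
    mats = [ssp.random(4, 5, density=0.25, format='csr') for _ in range(3)]
    a = CSRMatrix3d(mats)   # shape inferred as (3, 4, 5)
    first = a[0]            # CSRMatrix3d of shape (1, 4, 5)
    pair = a[0:2]           # CSRMatrix3d of shape (2, 4, 5)
    return first, pair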


class CSCMatrix3d(CSXMatrix3d):
    def __init__(self, inp, shape=None, device=None):
        if type(inp) == list and isinstance(inp[0], ssp.spmatrix):
            max_shape = [0, 0]
            for s in inp:
                max_shape[0] = max(max_shape[0], s.shape[0])
                max_shape[1] = max(max_shape[1], s.shape[1])
            if shape is None:
                shape = tuple([len(inp)] + max_shape)
            else:
                assert shape[0] == len(inp)
                assert shape[1] <= max_shape[0]
                assert shape[2] <= max_shape[1]
        elif type(inp) == list:
            assert shape is not None
            batch = shape[0]
            row = _max(inp[0])
            col = (len(inp[1]) - 1) // batch
            assert shape[1] >= row
            assert shape[2] == col
        super(CSCMatrix3d, self).__init__(inp, shape, device)

    @property
    def sptype(self):
        return 'csc'

    def transpose(self, keep_type=False):
        """
        Transpose each matrix in the batch.

        :param keep_type: If True, return a CSCMatrix3d via an explicit scipy
                          conversion; otherwise reinterpret the arrays as a CSR
                          matrix of the transpose, which is free.
        :return: transposed matrix
        """
        if not keep_type:
            # The (indices, indptr, data) arrays of a CSC matrix are exactly the
            # CSR arrays of its transpose, so only the shape needs swapping.
            shape_t = list(self.shape)
            shape_t[1], shape_t[2] = shape_t[2], shape_t[1]
            return CSRMatrix3d(self.as_list(), shape=shape_t, device=self.device)
        else:
            coo = []
            for sp in self.as_ssp():
                coo.append(sp.transpose().tocoo().astype(sp.dtype))
            return CSCMatrix3d(coo, device=self.device)

    def Tdot(self, other, *args, **kwargs):
        """
        Dot product of this CSC matrix TRANSPOSED with another CSC matrix.
        This is equivalent to CSR dot CSC.

        :param other: second CSC matrix
        :return: dot product as a new CSR matrix
        """
        t_csr = self.transpose()
        return dot(t_csr, other, *args, **kwargs)
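

# Illustrative sketch of the zero-copy transpose identity (the helper name and
# random input are hypothetical): the CSC arrays of a matrix are exactly the
# CSR arrays of its transpose, so transpose(keep_type=False) only swaps the
# shape and reuses the index arrays.
def _demo_transpose():
    m = CSCMatrix3d([ssp.random(3, 4, density=0.5, format='csc')])
    t = m.transpose()   # CSRMatrix3d of shape (1, 4, 3)
    return bool((t.numpy() == m.numpy().transpose(0, 2, 1)).all())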


class CSRMatrix3d(CSXMatrix3d):
    def __init__(self, inp, shape=None, device=None):
        if type(inp) == list and isinstance(inp[0], ssp.spmatrix):
            max_shape = [0, 0]
            for s in inp:
                max_shape[0] = max(max_shape[0], s.shape[0])
                max_shape[1] = max(max_shape[1], s.shape[1])
            if shape is None:
                shape = tuple([len(inp)] + max_shape)
            else:
                assert shape[0] == len(inp)
                assert shape[1] <= max_shape[0]
                assert shape[2] <= max_shape[1]
        elif type(inp) == list:
            assert shape is not None
            batch = shape[0]
            row = (len(inp[1]) - 1) // batch
            col = _max(inp[0])
            assert shape[1] == row
            assert shape[2] >= col
        super(CSRMatrix3d, self).__init__(inp, shape, device)

    @property
    def sptype(self):
        return 'csr'

    def transpose(self, keep_type=False):
        """
        Transpose each matrix in the batch.

        :param keep_type: If True, return a CSRMatrix3d via an explicit scipy
                          conversion; otherwise reinterpret the arrays as a CSC
                          matrix of the transpose, which is free.
        :return: transposed matrix
        """
        if not keep_type:
            # The (indices, indptr, data) arrays of a CSR matrix are exactly the
            # CSC arrays of its transpose, so only the shape needs swapping.
            shape_t = list(self.shape)
            shape_t[1], shape_t[2] = shape_t[2], shape_t[1]
            return CSCMatrix3d(self.as_list(), shape=shape_t, device=self.device)
        else:
            coo = []
            for sp in self.as_ssp():
                coo.append(sp.transpose().tocoo().astype(sp.dtype))
            return CSRMatrix3d(coo, device=self.device)

    def dot(self, other, *args, **kwargs):
        """
        Dot product of this CSR matrix and a CSC matrix.

        :param other: CSC matrix.
        :return: dot product in CSR matrix
        """
        return dot(self, other, *args, **kwargs)

    def dotdiag(self, other):
        """
        Dot product of this CSR matrix and the diagonal matrix built from a vector.

        :param other: input vector of shape (b, w); one diagonal per batch
        :return: dot product as a new CSR matrix
        """
        assert self.shape[0] == other.shape[0], 'Batch size mismatch'
        assert self.shape[2] == other.shape[1], 'Matrix shape mismatch'
        batch_size = self.shape[0]
        out_h = self.shape[1]
        out_w = self.shape[2]
        result = sparse_dot.csr_dot_diag(*self.as_list(), other, batch_size, out_h, out_w)
        ret = CSRMatrix3d(result, shape=self.shape)
        # Reference pure-PyTorch implementation (unused, kept for documentation):
        # indptr = self.indptr.clone()
        # indice = self.indices.clone()
        # data = self.data.clone()
        # for b in range(batch_size):
        #     start_idx = b * self.shape[1]
        #     end_idx = (b + 1) * self.shape[1] + 1
        #     indp_b = indptr[start_idx: end_idx]
        #     indx_b = indice[indp_b[0]: indp_b[-1]]
        #     data_b = data[indp_b[0]: indp_b[-1]]
        #     for j in range(self.shape[2]):
        #         data_b[indx_b == j] *= other[b, j]
        # ret = CSRMatrix3d([indice, indptr, data], self.shape)
        return ret
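

# Illustrative dotdiag sketch (hypothetical helper name and data; requires the
# compiled sparse_dot extension). As the reference implementation above shows,
# right-multiplying by diag(v) scales each stored value by the entry of v
# matching its column index.
def _demo_dotdiag():
    a = CSRMatrix3d([ssp.random(4, 5, density=0.5, format='csr') for _ in range(2)])
    v = torch.rand(2, 5, dtype=torch.float64)   # one diagonal vector per batch
    return a.dotdiag(v)                          # CSRMatrix3d of shape (2, 4, 5)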


def dot(csr: CSRMatrix3d, csc: CSCMatrix3d, dense=False):
    """
    Compute the dot product of one CSR matrix and one CSC matrix. The result is
    returned as a new CSR matrix or as a dense matrix. Note that the CUDA
    implementation only works with dense=True.

    :param csr: first input CSR matrix
    :param csc: second input CSC matrix
    :param dense: return the result as a dense matrix
    :return: dot result as a new CSR matrix (dense=False) or as a dense matrix (dense=True)
    """
    assert type(csr) == CSRMatrix3d
    assert type(csc) == CSCMatrix3d
    assert csr.shape[0] == csc.shape[0], 'Batch size mismatch'
    batch_num = csr.shape[0]
    assert csr.shape[2] == csc.shape[1], 'Matrix size mismatch'
    out_h = csr.shape[1]
    out_w = csc.shape[2]
    if csr.indptr.device == torch.device('cpu'):
        new_indices, new_indptr, new_data = \
            sparse_dot.csr_dot_csc(*csr.as_list(), *csc.as_list(), batch_num, out_h, out_w)
        ret = CSRMatrix3d([new_indices, new_indptr, new_data],
                          shape=(batch_num, out_h, out_w))
        if dense:
            ret = ret.numpy()
    else:
        if not dense:
            raise NotImplementedError('Sparse dot product output on CUDA is not implemented; '
                                      'pass dense=True.')
        ret = sparse_dot.csr_dot_csc_dense_cuda(*csr.as_list(), *csc.as_list(),
                                                batch_num, out_h, out_w)
    return ret
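

# Illustrative dot sketch (hypothetical helper name and data; requires the
# compiled sparse_dot extension). On CPU the result is a sparse CSRMatrix3d,
# or a dense numpy array when dense=True; on CUDA only dense=True is supported.
def _demo_dot():
    a = CSRMatrix3d([ssp.random(4, 5, density=0.3, format='csr') for _ in range(2)])
    b = CSCMatrix3d([ssp.random(5, 6, density=0.3, format='csc') for _ in range(2)])
    sparse_out = a.dot(b)              # CSRMatrix3d of shape (2, 4, 6)
    dense_out = a.dot(b, dense=True)   # numpy array of shape (2, 4, 6)
    return sparse_out, dense_out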


def concatenate(*mats: CSXMatrix3d, device=None):
    """
    Concatenate multiple sparse matrices along the first (batch) dimension.

    :param mats: sequence of input matrices
    :param device: Optional. Device of the result; defaults to the device of the first input.
    :return: concatenated matrix
    """
    device = mats[0].device if device is None else device
    mat_type = type(mats[0])
    mat_h = mats[0].shape[1]
    mat_w = mats[0].shape[2]
    batch_size = 0
    indptr_offset = 0
    indices = []
    indptr = []
    data = []
    for mat in mats:
        assert type(mat) == mat_type, 'Matrix type inconsistent'
        assert mat.shape[1] == mat_h, 'Matrix shape inconsistent in dimension 1'
        assert mat.shape[2] == mat_w, 'Matrix shape inconsistent in dimension 2'
        indices.append(mat.indices.clone().to(device))
        indptr.append(mat.indptr[:-1].clone().to(device) + indptr_offset)
        data.append(mat.data.clone().to(device))
        indptr_offset += mat.indptr[-1].to(device)
        batch_size += mat.shape[0]
    indptr.append(indptr_offset.view(1))
    indices = torch.cat(indices)
    indptr = torch.cat(indptr)
    data = torch.cat(data)
    return mat_type([indices, indptr, data], shape=(batch_size, mat_h, mat_w))
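

# Illustrative concatenate sketch (hypothetical helper name and data): batches
# of same-shaped h x w matrices are merged by shifting each input's indptr by
# the number of nonzeros accumulated so far.
def _demo_concatenate():
    a = CSRMatrix3d([ssp.random(3, 4, density=0.5, format='csr')])
    b = CSRMatrix3d([ssp.random(3, 4, density=0.5, format='csr') for _ in range(2)])
    return concatenate(a, b)   # CSRMatrix3d of shape (3, 3, 4)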


def _max(inp, *args, **kwargs):
    """Dispatch max() to numpy or torch depending on the input type."""
    if type(inp) == np.ndarray:
        return np.max(inp, *args, **kwargs)
    elif type(inp) == torch.Tensor:
        return torch.max(inp, *args, **kwargs)
    else:
        raise ValueError('Data type {} not understood.'.format(type(inp)))