src.utils.sparse.bilinear_diag_torch

src.utils.sparse.bilinear_diag_torch(s_t1: src.sparse_torch.csx_matrix.CSRMatrix3d, d_t2: torch.Tensor, s_t3: src.sparse_torch.csx_matrix.CSCMatrix3d, device=None)

Computes a bilinear product followed by a diagonal extraction, diagonal(sparse × dense × sparse), and returns the result as a dense tensor. With s_t1.shape = (b, x, y), d_t2.shape = (b, y, y) and s_t3.shape = (b, y, x), each batched product has shape (b, x, x), so its diagonal, the output, has shape (b, x). For efficient computation, the two sparse operands are stored in CSR format (s_t1) and CSC format (s_t3). The main operation is implemented in a custom C++ extension and runs roughly 1000x faster when CUDA is available.

:param s_t1: sparse CSR matrix of shape (b, x, y)
:param d_t2: dense tensor of shape (b, y, y)
:param s_t3: sparse CSC matrix of shape (b, y, x)
:param device: device for the output. If not specified, the device of the inputs is used.
:return: dense tensor of shape (b, x) holding, for each batch, the diagonal of s_t1 × d_t2 × s_t3
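Element-wise, the output is out[i] = Σ_j Σ_k s_t1[i, j] · d_t2[j, k] · s_t3[k, i]. This makes the choice of formats clear: CSR gives direct access to the nonzeros of row i of s_t1, and CSC gives direct access to the nonzeros of column i of s_t3. The sketch below is a minimal pure-Python illustration of that idea for a single batch element. It is not the C++ extension. The function names `diag_sparse_dense_sparse` and `diag_dense_reference` are hypothetical, and the `indptr`/`indices`/`data` arrays follow the usual compressed-sparse convention rather than the actual `CSRMatrix3d`/`CSCMatrix3d` API.

```python
from typing import List


def diag_sparse_dense_sparse(
    csr_indptr: List[int], csr_indices: List[int], csr_data: List[float],
    dense: List[List[float]],
    csc_indptr: List[int], csc_indices: List[int], csc_data: List[float],
) -> List[float]:
    """Hypothetical sketch: out[i] = sum_{j,k} A[i, j] * B[j, k] * C[k, i].

    A (x by y) is given in CSR, C (y by x) in CSC, B (y by y) is dense.
    Handles ONE batch element; a batched version would loop over b.
    """
    n = len(csr_indptr) - 1  # x: rows of A
    assert len(csc_indptr) - 1 == n, "C must have x columns"
    out = [0.0] * n
    for i in range(n):
        acc = 0.0
        # Nonzeros A[i, j]: contiguous slice of row i in CSR storage.
        for p in range(csr_indptr[i], csr_indptr[i + 1]):
            j, a_ij = csr_indices[p], csr_data[p]
            # Nonzeros C[k, i]: contiguous slice of column i in CSC storage.
            for q in range(csc_indptr[i], csc_indptr[i + 1]):
                k, c_ki = csc_indices[q], csc_data[q]
                acc += a_ij * dense[j][k] * c_ki
        out[i] = acc
    return out


def diag_dense_reference(A, B, C):
    """Dense check: diagonal of A @ B @ C, computed term by term."""
    x, y = len(A), len(B)
    return [sum(A[i][j] * B[j][k] * C[k][i] for j in range(y) for k in range(y))
            for i in range(x)]


# Example with x = 2, y = 3.
A = [[1, 0, 2],
     [0, 3, 0]]
B = [[1, 2, 0],
     [0, 1, 5],
     [3, 0, 1]]
C = [[0, 1],
     [2, 0],
     [0, 4]]

# A in CSR: row pointers, column indices, values.
a_indptr, a_indices, a_data = [0, 2, 3], [0, 2, 1], [1, 2, 3]
# C in CSC: column pointers, row indices, values.
c_indptr, c_indices, c_data = [0, 1, 3], [1, 0, 2], [2, 1, 4]

sparse_out = diag_sparse_dense_sparse(a_indptr, a_indices, a_data, B,
                                      c_indptr, c_indices, c_data)
print(sparse_out)                    # [4.0, 60.0]
print(diag_dense_reference(A, B, C))  # [4, 60]
```

In this sketch, each output entry costs nnz(row i of A) × nnz(column i of C) multiply-adds, and the full x × x product is never formed. For small dense PyTorch tensors, the same quantity can be cross-checked with `torch.einsum('bij,bjk,bki->bi', s1, d2, s3)`.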