src.utils.sparse.sss_bmm_diag_spp

src.utils.sparse.sss_bmm_diag_spp(s_m1, s_m2)

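Batched matrix-matrix multiplication (bmm) of two sparse tensors, with a sparse result. Only the diagonal of each product is computed, and it is returned as a sparse vector tensor. With s_m1.shape = (b, x, s) and s_m2.shape = (b, s, x), the full product would have shape (b, x, x); the output keeps only its diagonal, so its shape is (b, x), i.e. out[b, i] = sum_k s_m1[b, i, k] * s_m2[b, k, i]. This function does not support gradients.

:param s_m1: sparse matrix 1, shape (b, x, s)
:param s_m2: sparse matrix 2, shape (b, s, x)
:return: the diagonal of each product, as a sparse tensor of shape (b, x)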
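To illustrate what the function computes, here is a minimal pure-Python sketch. It is not this library's implementation: it assumes a hypothetical COO representation (a list of (batch, row, col, value) tuples) in place of the project's sparse tensors, and the name `bmm_diag_coo` is made up for illustration. Because only out[b, i] = sum_k m1[b, i, k] * m2[b, k, i] is needed, the sketch indexes the second operand by coordinate so that each nonzero of the first operand needs a single lookup, and it never forms the full (b, x, x) product.

```python
from collections import defaultdict
from typing import Dict, Iterable, Tuple

# Hypothetical COO layout: one (batch, row, col, value) tuple per nonzero.
Entry = Tuple[int, int, int, float]


def bmm_diag_coo(
    m1: Iterable[Entry], m2: Iterable[Entry]
) -> Dict[Tuple[int, int], float]:
    """Return diag(m1[b] @ m2[b]) for every batch b as a sparse map.

    m1 has logical shape (b, x, s) and m2 has logical shape (b, s, x).
    The result maps (batch, i) -> sum_k m1[batch, i, k] * m2[batch, k, i]
    and stores only nonzero entries, i.e. a sparse (b, x) tensor.
    """
    # Index m2 by (batch, row k, col i); duplicate coordinates are summed.
    m2_index: Dict[Tuple[int, int, int], float] = {}
    for b, k, i, v in m2:
        m2_index[(b, k, i)] = m2_index.get((b, k, i), 0.0) + v

    out: Dict[Tuple[int, int], float] = defaultdict(float)
    for b, i, k, v in m1:
        # Diagonal entry (i, i) pairs m1[b, i, k] with m2[b, k, i].
        w = m2_index.get((b, k, i))
        if w is not None:
            out[(b, i)] += v * w

    # Drop entries that cancelled to exactly zero so the result stays sparse.
    return {key: val for key, val in out.items() if val != 0.0}


# Batch 0: [[1, 2], [0, 3]] @ [[4, 0], [5, 6]] = [[14, 12], [15, 18]]
# Batch 1: [[0, 0], [7, 0]] @ [[0, 1], [0, 0]] = [[0, 0], [0, 7]]
m1 = [(0, 0, 0, 1.0), (0, 0, 1, 2.0), (0, 1, 1, 3.0), (1, 1, 0, 7.0)]
m2 = [(0, 0, 0, 4.0), (0, 1, 0, 5.0), (0, 1, 1, 6.0), (1, 0, 1, 1.0)]
print(bmm_diag_coo(m1, m2))  # {(0, 0): 14.0, (0, 1): 18.0, (1, 1): 7.0}
```

The cost of the sketch is linear in the number of nonzeros of both operands, which is the practical reason to ask for the diagonal directly instead of computing the full batched product and then extracting it.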