src.utils.sparse.sdd_bmm_torch

src.utils.sparse.sdd_bmm_torch(s_t1, d_t2)

Batched matrix-matrix multiplication (bmm) for sparse x dense -> dense. This function does not itself support gradients.

With s_t1.shape = (b, x, s) and d_t2.shape = (b, s, y), the output shape is (b, x, y). This is a workaround that uses torch.mm to compute sparse x dense -> dense.

:param s_t1: sparse tensor 1, of shape (b, x, s)
:param d_t2: dense tensor 2, of shape (b, s, y)
:return: the bmm result as a dense tensor of shape (b, x, y)
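The workaround described above can be illustrated with a minimal sketch. This is not the library's actual implementation: the name `sdd_bmm_sketch` is hypothetical, and the sketch assumes `s_t1` is a 3-D sparse COO tensor whose first dimension is the batch. It splits the batch into 2-D sparse slices, multiplies each slice with `torch.mm`, and stacks the dense results.

```python
import torch


def sdd_bmm_sketch(s_t1, d_t2):
    """Hypothetical sketch of sparse x dense -> dense bmm via per-batch torch.mm."""
    # Assumption: s_t1 is a sparse COO tensor of shape (b, x, s).
    s_t1 = s_t1.coalesce()
    b, x, s = s_t1.shape
    indices = s_t1.indices()  # shape (3, nnz): batch, row, column
    values = s_t1.values()    # shape (nnz,)

    outputs = []
    for i in range(b):
        # Select the non-zero entries that belong to batch element i.
        mask = indices[0] == i
        s_2d = torch.sparse_coo_tensor(indices[1:, mask], values[mask], (x, s))
        # torch.mm accepts a 2-D sparse matrix times a 2-D dense matrix
        # and returns a dense (x, y) result.
        outputs.append(torch.mm(s_2d, d_t2[i]))

    # Stack the per-batch results into a dense (b, x, y) tensor.
    return torch.stack(outputs, dim=0)


# Usage: compare against the dense reference torch.bmm.
dense_ref = torch.tensor([[[1.0, 0.0], [0.0, 2.0]],
                          [[0.0, 3.0], [0.0, 0.0]]])
result = sdd_bmm_sketch(dense_ref.to_sparse(), torch.ones(2, 2, 3))
```

The Python loop over the batch dimension keeps the sketch simple; it trades speed for clarity and is only meant to show how the shapes (b, x, s) and (b, s, y) combine into (b, x, y).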