src.utils.sparse.sds_bmm_torch

src.utils.sparse.sds_bmm_torch(s_t1, d_t2)

bmm (batch matrix-matrix product) for sparse x dense -> sparse. With s_t1.shape = (b, x, s) and d_t2.shape = (b, s, y), the output shape is (b, x, y). This is a workaround built on torch.smm, which computes sparse x dense -> sparse for a single pair of 2-D matrices. This function does not support gradients, because sparse tensors cannot receive gradients due to a limitation of the torch implementation.

:param s_t1: sparse tensor 1, given as a list whose entries represent the batches
:param d_t2: dense tensor 2
:return: bmm result as sparse tensors, given as a list whose entries represent the batches
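Because torch.smm only multiplies one 2-D sparse matrix by one 2-D dense matrix, a batched version can be sketched as a loop over the batch dimension. The following is a minimal sketch under that assumption, not the project's actual implementation: the name `sds_bmm_sketch` is hypothetical, and it assumes `s_t1` is a list of b sparse COO tensors of shape (x, s) and `d_t2` is a dense tensor of shape (b, s, y) with a matching dtype.

```python
import torch


def sds_bmm_sketch(s_t1, d_t2):
    """Hypothetical sketch: per-batch sparse x dense -> sparse via torch.smm.

    s_t1: list of b sparse COO tensors, each of shape (x, s)
    d_t2: dense tensor of shape (b, s, y)
    returns: list of b sparse COO tensors, each of shape (x, y)
    """
    if len(s_t1) != d_t2.shape[0]:
        raise ValueError("batch sizes of s_t1 and d_t2 differ")
    # torch.smm handles a single 2-D sparse x 2-D dense product and returns
    # a sparse tensor, so the batch dimension is handled by iterating.
    return [torch.smm(s, d) for s, d in zip(s_t1, d_t2)]


# Usage: build a sparse batch from a dense tensor with some entries zeroed.
b, x, s, y = 2, 3, 4, 5
dense_a = torch.randn(b, x, s)
dense_a[dense_a.abs() < 0.5] = 0.0
s_t1 = [m.to_sparse() for m in dense_a]  # list of b sparse (x, s) matrices
d_t2 = torch.randn(b, s, y)

out = sds_bmm_sketch(s_t1, d_t2)
# Compare against the dense batched product as a sanity check.
ref = torch.bmm(dense_a, d_t2)
print(torch.allclose(torch.stack([o.to_dense() for o in out]), ref, atol=1e-5))
```

Returning a list rather than a stacked 3-D sparse tensor mirrors the docstring's "in list, representing batches" convention; the dense `torch.bmm` comparison is only a correctness check for the example.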