src.utils.sparse.sdd_bmm_diag_torch

src.utils.sparse.sdd_bmm_diag_torch(t1, t2)

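Perform a batched matrix multiplication (bmm) of a sparse tensor by a dense tensor and take the diagonal of each product, returning a dense result. The diagonal of each batch element is returned as a vector. With t1 sparse of shape (b, x, s) and t2 dense of shape (b, s, x), the output shape is (b, x). For memory efficiency, this method avoids materializing a temporary (b, x, x) tensor.

:param t1: tensor 1 (sparse)
:param t2: tensor 2 (dense)
:return: bmm_diag result as a dense tensor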
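The memory saving comes from the identity diag(t1 @ t2)[b, i] = sum_k t1[b, i, k] * t2[b, k, i]: only the diagonal entries are needed, so the full (b, x, x) product never has to exist. The following is a minimal sketch of that idea, assuming t1 is a PyTorch COO sparse tensor. The name `sdd_bmm_diag_sketch` is hypothetical, and this is not necessarily how `sdd_bmm_diag_torch` is implemented.

```python
import torch


def sdd_bmm_diag_sketch(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: diagonal of bmm(t1, t2) per batch, no (b, x, x) temp.

    t1: sparse COO tensor of shape (b, x, s)  (assumed layout)
    t2: dense tensor of shape (b, s, x)
    returns: dense tensor of shape (b, x)
    """
    b, x, s = t1.shape
    assert t2.shape == (b, s, x), "t2 must have shape (b, s, x)"
    t1 = t1.coalesce()
    # One column per nonzero: batch index, row index, column index.
    bi, ri, ci = t1.indices()
    vals = t1.values()
    # out[b, i] = sum_k t1[b, i, k] * t2[b, k, i]; only nonzeros of t1 contribute.
    contrib = vals * t2[bi, ci, ri]
    out = torch.zeros(b, x, dtype=contrib.dtype, device=t2.device)
    # Sum contributions that share the same (batch, row) position.
    out.index_put_((bi, ri), contrib, accumulate=True)
    return out
```

Because the sketch only visits the nonzeros of t1, its extra memory scales with the number of nonzeros rather than with b·x·x. Its result can be checked against `torch.diagonal(torch.bmm(dense_t1, t2), dim1=1, dim2=2)` on small inputs.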