src.utils.sparse.get_batches

src.utils.sparse.get_batches(s_t, b=None, device=None)

Get batches from a 3D sparse tensor.

:param s_t: sparse tensor
:param b: if None, return all batches in a list; otherwise, return only batch b
:param device: device of the output. If None, it is the same as the input's device
:return: a sparse tensor, or a list of sparse tensors
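The docstring does not specify the storage format or which dimension holds the batch. The code below is a minimal pure-Python sketch of the splitting logic under three assumptions. First, the tensor is stored in COO form, as one index tuple per stored value. Second, the batch index is dimension 0. Third, each returned batch drops that dimension. `SparseCOO` and `get_batches_sketch` are hypothetical names and are not part of `src.utils.sparse`. The `device` argument is left out because plain Python lists have no device. In a tensor library it would normally be passed on to whatever constructor builds the output tensors.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union


@dataclass
class SparseCOO:
    """Hypothetical minimal COO tensor: one index tuple per stored value."""
    indices: List[Tuple[int, ...]]
    values: List[float]
    shape: Tuple[int, ...]


def get_batches_sketch(
    s_t: SparseCOO, b: Optional[int] = None
) -> Union[SparseCOO, List[SparseCOO]]:
    """Split a 3D COO tensor along dimension 0 (assumed to be the batch dim)."""
    if len(s_t.shape) != 3:
        raise ValueError("expected a 3D sparse tensor")
    n_batches, rest_shape = s_t.shape[0], s_t.shape[1:]

    if b is not None:
        if not 0 <= b < n_batches:
            raise IndexError(f"batch {b} out of range for {n_batches} batches")
        # Keep only the entries whose batch index equals b, dropping that index.
        pairs = [(idx[1:], v) for idx, v in zip(s_t.indices, s_t.values) if idx[0] == b]
        return SparseCOO([p[0] for p in pairs], [p[1] for p in pairs], rest_shape)

    # Bucket every stored entry by its batch index in a single pass.
    idx_buckets: List[List[Tuple[int, ...]]] = [[] for _ in range(n_batches)]
    val_buckets: List[List[float]] = [[] for _ in range(n_batches)]
    for idx, v in zip(s_t.indices, s_t.values):
        idx_buckets[idx[0]].append(idx[1:])
        val_buckets[idx[0]].append(v)
    return [SparseCOO(idx_buckets[i], val_buckets[i], rest_shape) for i in range(n_batches)]


# Usage: a 2 x 3 x 3 tensor with one stored entry in batch 0 and two in batch 1.
s = SparseCOO(indices=[(0, 0, 1), (1, 2, 0), (1, 0, 0)], values=[1.0, 2.0, 3.0], shape=(2, 3, 3))
all_batches = get_batches_sketch(s)   # list of two 3 x 3 COO tensors
batch_one = get_batches_sketch(s, b=1)
print(batch_one.indices)              # [(2, 0), (0, 0)]
```

Returning a list when `b` is None and a single tensor otherwise follows the documented `:return:` contract. The single-batch path filters directly instead of building every batch, so it avoids extra work when only one batch is needed.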