Source code for src.utils.data_to_cuda

import torch
from src.sparse_torch.csx_matrix import CSRMatrix3d, CSCMatrix3d
import torch_geometric as pyg

def data_to_cuda(inputs):
    """
    Recursively move every tensor-like element of ``inputs`` to cuda.

    Handling per element type:

    * ``None``, ``str``, ``int`` (including ``bool``) and ``float``: returned unchanged.
    * ``list`` and ``dict`` (including subclasses): converted **in place**; the same
      container object is returned.
    * ``tuple`` (exact type only): converted into a new ``list`` (not a tuple; kept for
      backward compatibility with existing callers).
    * ``torch.Tensor``, ``CSRMatrix3d``, ``CSCMatrix3d``: moved with ``.cuda()``.
    * ``torch_geometric`` ``Data`` / ``Batch`` objects, including the batch classes that
      torch_geometric generates at runtime (e.g. ``DataBatch``): moved with ``.to('cuda')``.

    :param inputs: input list/tuple/dictionary (possibly nested) or a single element
    :return: ``inputs`` with all its tensor-like elements on cuda
    :raises TypeError: if an element of an unsupported type is encountered
    """
    # Leaf values that need no device transfer; ``bool`` is covered by ``int``.
    if inputs is None or isinstance(inputs, (str, int, float)):
        return inputs

    if isinstance(inputs, list):
        for i, x in enumerate(inputs):
            inputs[i] = data_to_cuda(x)
        return inputs

    # Exact check so namedtuples are not silently flattened into plain lists.
    if type(inputs) is tuple:
        return [data_to_cuda(x) for x in inputs]

    if isinstance(inputs, dict):
        # Re-assigning existing keys does not change the dict's size, so this is
        # safe while iterating.
        for key in inputs:
            inputs[key] = data_to_cuda(inputs[key])
        return inputs

    if isinstance(inputs, (torch.Tensor, CSRMatrix3d, CSCMatrix3d)):
        return inputs.cuda()

    # isinstance (rather than an exact-type lookup of ``pyg.data.batch.DataBatch``)
    # also matches batch classes torch_geometric creates dynamically as subclasses
    # of ``Batch``, regardless of library version or whether one was built yet.
    if isinstance(inputs, (pyg.data.Data, pyg.data.Batch)):
        return inputs.to('cuda')

    raise TypeError('Unknown type of inputs: {}'.format(type(inputs)))