Source code for src.parallel.data_parallel

import torch.nn as nn
from .scatter_gather import scatter_kwargs, gather


class DataParallel(nn.DataParallel):
    """DataParallel wrapper with customized scatter/gather functions."""

    def __init__(self, *args, **kwargs):
        super(DataParallel, self).__init__(*args, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)
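

A minimal usage sketch, not part of the original module: it wraps a toy model with this DataParallel subclass so that the customized scatter_kwargs/gather are used for splitting inputs and collecting outputs. The model, tensor shapes, and device handling below are illustrative assumptions, not taken from the source.

if __name__ == "__main__":
    import torch

    # Toy model; any nn.Module works here (illustrative assumption).
    model = nn.Linear(16, 4)
    x = torch.randn(8, 16)

    if torch.cuda.device_count() > 1:
        # Replicate the model across all visible GPUs; inputs are split
        # along dim=0 by scatter(), and per-GPU outputs are concatenated
        # back on the output device by gather().
        model = DataParallel(model.cuda())
        x = x.cuda()

    out = model(x)
    print(out.shape)  # torch.Size([8, 4])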