Source code for src.parallel.data_parallel
import torch.nn as nn
from .scatter_gather import scatter_kwargs, gather
[docs]class DataParallel(nn.DataParallel):
"""
DataParallel wrapper with customized scatter/gather functions
"""
def __init__(self, *args, **kwargs):
super(DataParallel, self).__init__(*args, **kwargs)
[docs] def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
[docs] def gather(self, outputs, output_device):
return gather(outputs, output_device, dim=self.dim)