PyTorch parallel and distributed training: DataParallel principle, source code analysis, and examples in practice

Brief overview

   The data-parallel class officially provided by PyTorch is:

torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)

   Given a model, the main job of this class is to split the input data along the batch dimension across the specified devices; other objects are replicated to each device. During forward propagation, the module is copied to each device and each replica processes its slice of the input. During backward propagation, the gradients from each replica are accumulated into the original module (usually on GPU 0).

One thing to note is that the official documentation recommends using DistributedDataParallel instead, because DistributedDataParallel uses multiple processes while DataParallel uses multiple threads. If you use DistributedDataParallel, you need to launch the program with torch.distributed.launch; refer to Distributed Communication Package - torch.distributed.

   The batch size must be greater than the number of GPUs. In my practice, I generally set the batch size to a multiple of the number of GPUs. Data passed into the module can also be passed into the DataParallel wrapper (the parallelized module). By default, tensors are split along dim=0 across the devices; tuple, list, and dict data are shallow-copied to the different GPUs, and other data types are shared among the worker threads.
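As a rough illustration of that dim-0 split (a pure-Python sketch; `chunk_sizes` is a hypothetical helper, not a PyTorch API), the chunks follow torch.chunk-style near-equal sizing, which is why a batch size that is a multiple of the GPU count keeps the replicas evenly loaded:

```python
import math

def chunk_sizes(batch_size, n_gpus):
    """Near-equal dim-0 chunk sizes, torch.chunk style: each chunk holds
    ceil(batch/n) samples, and the last chunk takes whatever is left."""
    step = math.ceil(batch_size / n_gpus)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        sizes.append(min(step, remaining))
        remaining -= step
    return sizes

print(chunk_sizes(30, 3))   # [10, 10, 10] -- balanced when batch % n_gpus == 0
print(chunk_sizes(100, 3))  # [34, 34, 32] -- uneven otherwise
```

A batch of 100 on 3 GPUs therefore gives one GPU slightly less work per step, which is harmless but wastes a little capacity.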

   Before calling DataParallel, the module must already have its parameters and buffers on the designated GPU (otherwise a runtime error is raised).

During forward propagation the module is replicated to each device, so any attribute updates made during the forward pass are lost. For example, if the module has a counter attribute incremented on every forward pass, it will remain at its initial value, because the update happens on a replica that is destroyed once the forward pass finishes. The exception is the replica on device[0], which shares its parameters and buffers with the original module, so updates made on device[0] are preserved.
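This lost-update behavior can be mimicked on the CPU with a toy sketch (no CUDA involved; `Counter` and `replicate_sketch` are hypothetical stand-ins, not PyTorch APIs): the first replica aliases the original module, like the copy on device_ids[0], while the others are independent copies, so only changes made through the first replica survive.

```python
import copy

class Counter:
    """Toy module with a counter attribute bumped on every forward pass."""
    def __init__(self):
        self.n = 0

    def forward(self):
        self.n += 1
        return self.n

def replicate_sketch(module, n_devices):
    # Rough stand-in for replicate(): the first replica shares state with the
    # original (like the copy on device_ids[0]); the rest are separate copies.
    return [module] + [copy.deepcopy(module) for _ in range(n_devices - 1)]

counter = Counter()
replicas = replicate_sketch(counter, 3)
replicas[1].forward()   # update on a throwaway replica: lost
replicas[2].forward()   # likewise
print(counter.n)        # 0
replicas[0].forward()   # update via the device[0]-style replica: preserved
print(counter.n)        # 1
```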

The returned result is the outputs from all devices gathered together, by default concatenated along dim 0. Pay attention to this when processing RNN sequence data; see the FAQ entry "My recurrent network doesn't work with data parallelism".

torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)

  The torch.nn.DataParallel() function takes three main parameters: module, device_ids, and output_device.

  1. module is the module to be parallelized.
  2. device_ids is a list, defaulting to all available devices.
  3. output_device is the GPU on which the gathered output is placed; the default is device_ids[0].

  A simple example:

>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU

Source code analysis

  The source file is torch/nn/parallel/data_parallel.py.

  Annotated source:

import operator
import torch
import warnings
from itertools import chain
from ..modules import Module
from .scatter_gather import scatter_kwargs, gather
from .replicate import replicate
from .parallel_apply import parallel_apply
from torch._utils import (
    _get_all_device_indices,
    _get_available_device_type,
    _get_device_index,
    _get_devices_properties
)

def _check_balance(device_ids):
    imbalance_warn = """
    There is an imbalance between your GPUs. You may want to exclude GPU {} which
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable."""
    device_ids = [_get_device_index(x, True) for x in device_ids]
    dev_props = _get_devices_properties(device_ids)

    def warn_imbalance(get_prop):
        values = [get_prop(props) for props in dev_props]
        min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
        max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
        if min_val / max_val < 0.75:
            warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
            return True
        return False

    if warn_imbalance(lambda props: props.total_memory):
        return
    if warn_imbalance(lambda props: props.multi_processor_count):
        return

DataParallel class initialization:

class DataParallel(Module):
    # TODO: update notes/cuda.rst when this class handles 8+ GPUs well

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()
        # _get_available_device_type() calls torch.cuda.is_available() and returns "cuda" or None.
        device_type = _get_available_device_type() 
        if device_type is None: # Check for a GPU
            # Without a GPU the module cannot be parallelized; keep it on the
            # CPU with an empty device list and return early.
            self.module = module
            self.device_ids = []
            return

        if device_ids is None: # If no GPU is specified, all available GPUs are used by default
            # Get all available device IDs as a list.
            device_ids = _get_all_device_indices()

        if output_device is None: # Determine whether the output device is specified
            output_device = device_ids[0] # Defaults to the first of the specified devices

        self.dim = dim
        self.module = module # the module passed in
        self.device_ids = [_get_device_index(x, True) for x in device_ids]
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device(device_type, self.device_ids[0])


        _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            # With a single device there is nothing to parallelize; just move
            # the module onto it.
            self.module.to(self.src_device_obj)

Forward propagation

    def forward(self, *inputs, **kwargs):
    	# If no GPU is available, the original module is used for calculation
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
		# Check that the module's parameters and buffers are on the source device.
        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

        # Split the input evenly across the GPUs
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 
        # for forward function without any inputs, empty list and dict will be created
        # so the module can be executed on one device which is the first one in device_ids
        if not inputs and not kwargs:
            inputs = ((),)
            kwargs = ({},)

        if len(self.device_ids) == 1: # With only one GPU, call the unparallelized module directly
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) # Copy the model to each GPU
        outputs = self.parallel_apply(replicas, inputs, kwargs) # Run the replicas on their GPUs in parallel
        return self.gather(outputs, self.output_device) # Gather the outputs onto output_device, by default along dim 0

    def replicate(self, module, device_ids):
        return replicate(module, device_ids, not torch.is_grad_enabled())

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)
  • scatter function:
def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None
    return res

   In forward propagation, the data is distributed to the GPUs by the scatter function, whose code lives in scatter_gather.py. If an input is not a tensor, it is handled according to its container type and scatter_map is called recursively; ultimately the Scatter.apply method splits each tensor across the given GPUs and returns the chunks.
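To see how scatter_map transposes nested containers into per-device inputs, here is a CPU-only sketch (`FakeTensor` and `scatter_map_sketch` are hypothetical stand-ins for torch.Tensor and scatter_map; the recursion mirrors the branches above, with the list branch omitted for brevity):

```python
import math

class FakeTensor:
    """Minimal tensor stand-in: a list of rows plus a torch.chunk-like split."""
    def __init__(self, rows):
        self.rows = list(rows)

    def chunk(self, n):
        step = math.ceil(len(self.rows) / n)
        return [FakeTensor(self.rows[i:i + step])
                for i in range(0, len(self.rows), step)]

def scatter_map_sketch(obj, n):
    # Tensors are split along dim 0; containers are scattered element-wise and
    # transposed so each device gets one container; everything else is replicated.
    if isinstance(obj, FakeTensor):
        return obj.chunk(n)
    if isinstance(obj, tuple) and obj:
        return list(zip(*(scatter_map_sketch(o, n) for o in obj)))
    if isinstance(obj, dict) and obj:
        return [dict(kvs) for kvs in
                zip(*(scatter_map_sketch(kv, n) for kv in obj.items()))]
    return [obj for _ in range(n)]

batch = FakeTensor([0, 1, 2, 3])
per_device = scatter_map_sketch((batch, "label"), 2)
print([(t.rows, s) for t, s in per_device])  # [([0, 1], 'label'), ([2, 3], 'label')]
```

Note how the non-tensor `"label"` is simply duplicated to every device, while the tensor is chunked, exactly as described above.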

  • replicate function:

  The replicate function copies the model to each GPU. If the model you define is a ScriptModule, that is, your model inherits from torch.jit.ScriptModule rather than nn.Module, it cannot be copied and an error will be raised.

   This function mainly copies the model's parameters, buffers, and other shared state to each GPU; read the source yourself if you are interested.


def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (list of int or torch.device): GPU location of the output.
            Use -1 to indicate the CPU. (default: device_ids[0])

    Returns:
        a Tensor containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs,) if inputs is not None else ()

    device_type = _get_available_device_type()

    if device_ids is None:
        device_ids = _get_all_device_indices()

    if output_device is None:
        output_device = device_ids[0]

    device_ids = [_get_device_index(x, True) for x in device_ids]
    output_device = _get_device_index(output_device, True)
    src_device_obj = torch.device(device_type, device_ids[0])

    for t in chain(module.parameters(), module.buffers()):
        if t.device != src_device_obj:
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(src_device_obj, t.device))

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    # for module without any inputs, empty list and dict will be created
    # so the module can be executed on one device which is the first one in device_ids
    if not inputs and not module_kwargs:
        inputs = ((),)
        module_kwargs = ({},)

    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)

   Once the parallel replicas and the scattered data are both available, they are used together for the computation.
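The whole functional flow (scatter → replicate → parallel apply → gather) can be mimicked end to end on the CPU with plain Python; `data_parallel_sketch` below is a hypothetical illustration, not the real CUDA implementation:

```python
import math
import threading

def data_parallel_sketch(fn, batch, n_devices):
    """Toy end-to-end sketch of the functional data_parallel flow on CPU:
    scatter the batch into near-equal chunks, apply fn to each chunk in its
    own thread (one 'replica' per device), then gather in device order."""
    step = math.ceil(len(batch) / n_devices)
    chunks = [batch[i:i + step] for i in range(0, len(batch), step)]  # scatter
    results = {}
    lock = threading.Lock()

    def _worker(i, chunk):
        out = [fn(x) for x in chunk]     # each "replica" processes its chunk
        with lock:
            results[i] = out

    threads = [threading.Thread(target=_worker, args=(i, c))
               for i, c in enumerate(chunks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    gathered = []
    for i in range(len(chunks)):         # gather along dim 0, in order
        gathered.extend(results[i])
    return gathered

print(data_parallel_sketch(lambda x: x * x, list(range(6)), 3))  # [0, 1, 4, 9, 16, 25]
```

Indexing the results dict by replica number is what keeps the gathered output in the same order as the original batch, regardless of which thread finishes first.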

  • parallel_apply function:
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    # Check that the number of replicas matches the number of input chunks
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device), autocast(enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs

  First, check that the lengths of the inputs match; then use multithreading to run each replica on its chunk of the data; finally, gather all outputs together, by default along dim 0.
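The threading-plus-results-dict pattern of parallel_apply boils down to the following CPU-only sketch (`parallel_apply_sketch` is hypothetical, with plain exceptions standing in for ExceptionWrapper): each worker records either its output or the exception it hit, and exceptions are re-raised only when results are collected.

```python
import threading

def parallel_apply_sketch(fns, inputs):
    """Minimal sketch of parallel_apply's threading pattern (no CUDA):
    run each fn on its input in a thread, collect results by index,
    and defer any exception until the results are gathered."""
    assert len(fns) == len(inputs)
    lock = threading.Lock()
    results = {}

    def _worker(i, fn, inp):
        try:
            out = fn(inp)
            with lock:
                results[i] = out
        except Exception as e:        # stand-in for ExceptionWrapper
            with lock:
                results[i] = e

    threads = [threading.Thread(target=_worker, args=(i, fn, inp))
               for i, (fn, inp) in enumerate(zip(fns, inputs))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    outputs = []
    for i in range(len(fns)):
        out = results[i]
        if isinstance(out, Exception):
            raise out                 # re-raise in the caller's thread
        outputs.append(out)
    return outputs

print(parallel_apply_sketch([lambda x: x + 1, lambda x: x * 2], [1, 3]))  # [2, 6]
```

Deferring the raise is important: a failure in one replica must not kill the worker thread silently, and the caller sees the error on its own thread, as with ExceptionWrapper.reraise().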


import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):
        return self.len

class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
        self.sigmoid = nn.Sigmoid()
        # self.modules = [self.fc, self.sigmoid]

    def forward(self, input):
        return self.sigmoid(self.fc(input))

if __name__ == '__main__':
    # Parameters and DataLoaders
    input_size = 5
    output_size = 1
    batch_size = 30
    data_size = 100

    rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
                             batch_size=batch_size, shuffle=True)

    model = Model(input_size, output_size)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model).cuda()

    optimizer = optim.SGD(params=model.parameters(), lr=1e-3)
    cls_criterion = nn.BCELoss()

    for data in rand_loader:
        targets = torch.empty(data.size(0)).random_(2).view(-1, 1)

        if torch.cuda.is_available():
            input = Variable(data.cuda())
            with torch.no_grad():
                targets = Variable(targets.cuda())
        else:
            input = Variable(data)
            with torch.no_grad():
                targets = Variable(targets)

        output = model(input)

        loss = cls_criterion(output, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Tags: Pytorch

Posted by Devil_Banner on Tue, 03 May 2022 17:24:12 +0300