Distributed Training in Deep Learning Models
Updated: Jul 30, 2021
A brief overview of deep learning
Deep learning is a subfield of machine learning that has performed remarkably well across a wide variety of tasks such as image and speech recognition, natural language processing, autonomous driving, and medical research. Most commonly, deep learning is based on artificial neural networks. The architecture of a standard neural network is inspired by the way the human brain functions. It is composed of connected nodes called neurons, which, through a series of real-valued activations, help the network learn complex functions. A neural network consists of several stages: an input layer, one or more hidden layers, and an output layer.
[Figure: Neural network representation]
Training a neural network involves finding an optimized set of parameters for these stages so that the network eventually exhibits the desired behavior, e.g., predicting whether an image contains a particular object of interest, or translating an English sentence into French. Shallow neural networks, which consist of only a few stages, have been around since the 1960s. Deep neural networks, which are useful for learning highly complex problems, contain many layers or stages and require a large amount of training data to reach the best accuracy on the desired inference task. Although the concept has been around for several decades, deep learning has become popular in recent years due to the availability of improved GPU hardware and large volumes of human-generated digital data.
Why distributed deep learning?
As discussed, to learn complex functions we require a large deep learning network with several layers. The classification accuracy of a deep learning model can increase with the increasing number of training examples, the number of model parameters, or both. However, training large networks is computationally expensive and can take an impractically long time when trained on a single machine, even if it supports multithreading. This calls for scaling up the training of these models across multiple connected machines in a distributed manner. With the increased availability of GPUs, the model training can be distributed across multi-GPU clusters. Before we dive further, it’s important to understand the concepts below that are related to the training of a neural network:
Weights
The network is initialized with random weights. These weights are multiplied by the inputs to the neurons. The weights are adjusted through multiple iterations during the training process with the aim of reducing the error in predicting the output. The intuition is to tune the weights such that the network produces an output as close as possible to the actual output for the corresponding input values.
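As a small illustration of how weights act on a neuron's inputs, here is a minimal sketch of a single neuron. The sigmoid activation, the bias term, and the names neuron_output, inputs, and weights are assumptions made for this example and are not prescribed by the text above.

```python
import math
import random

def neuron_output(inputs, weights, bias):
    """Weighted sum of the inputs passed through a sigmoid activation."""
    z = sum(w * x for w, x in zip(weights, inputs)) + bias
    return 1.0 / (1.0 + math.exp(-z))

rng = random.Random(0)
inputs = [0.5, -1.0, 2.0]                              # hypothetical input values
weights = [rng.uniform(-1.0, 1.0) for _ in inputs]     # random initialization
output = neuron_output(inputs, weights, bias=0.0)
print(output)  # some value strictly between 0 and 1
```

Training would then adjust the values in weights so that the output moves closer to the desired target for each input.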
Gradient descent
As we saw, the network needs to learn optimal weights for predicting the correct output. The loss for a single training example measures the error in the predicted output for that example. The cost function is the average of the losses over the entire training set and is a measure of how well the weights are doing on the training set. The aim is to make the cost function as low as possible. For intuition, the cost function J(w) plotted against the weights w is often pictured as a convex, bowl-shaped function, as shown below (for deep networks the actual cost surface is generally not convex):
[Figure: Gradient descent]
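In symbols, the cost function described above could be written as follows. The notation is introduced here for illustration only: m is the number of training examples, L is the per-example loss, and \hat{y}^{(i)} and y^{(i)} are the predicted and actual outputs for example i.

```latex
J(w) = \frac{1}{m} \sum_{i=1}^{m} L\left(\hat{y}^{(i)}, y^{(i)}\right)
```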
The gradient descent algorithm aims to reach the minimum of this cost function. At every iteration, it takes a small step along the downward slope (the direction of the negative gradient) and converges to the minimum after many such steps.
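The iteration can be sketched on a one-dimensional toy problem. The cost J(w) = (w - 3)^2, the learning rate, and the step count below are hypothetical choices for illustration; each step applies the update w = w - learning_rate * dJ/dw.

```python
def cost(w):
    """Hypothetical bowl-shaped cost J(w) = (w - 3)^2, minimized at w = 3."""
    return (w - 3.0) ** 2

def gradient(w):
    """Derivative dJ/dw of the cost above."""
    return 2.0 * (w - 3.0)

def gradient_descent(w_init, learning_rate=0.1, steps=100):
    w = w_init
    for _ in range(steps):
        # Take a small step along the downward slope of the cost.
        w -= learning_rate * gradient(w)
    return w

w_final = gradient_descent(w_init=-5.0)
print(w_final, cost(w_final))  # w_final ends up very close to 3
```

A learning rate that is too large would overshoot the minimum, while one that is too small would need many more steps to get there.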
Stochastic Gradient Descent (SGD): In gradient descent, the set of examples used to calculate the cost function in a single iteration constitutes a batch; in standard (batch) gradient descent this is the entire training set. A large dataset may contain redundant data, so a gradient estimated from a single randomly selected example is likely to point in roughly the right direction. Stochastic gradient descent therefore uses only one training example per iteration, which requires much less computation per step but makes the updates noisy.
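Mini-batch SGD: Instead of a single example, SGD can be performed on mini-batches, typically consisting of 10 to 1,000 randomly chosen examples. This reduces the noise of single-example SGD while remaining cheaper per step than using the full batch.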
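The difference between the two variants comes down to how many examples feed each update. Below is a minimal sketch in plain Python that fits a line to synthetic data; the target y = 2x + 1, the squared-error loss, and all function names are assumptions made for this example. Setting batch_size=1 gives single-example SGD, and larger values give mini-batch SGD.

```python
import random

def make_data(n=200):
    """Noise-free samples of the hypothetical target y = 2x + 1."""
    return [(i / 100.0, 2.0 * (i / 100.0) + 1.0) for i in range(n)]

def minibatch_gradient(w, b, batch):
    """Gradient of the mean squared error over one mini-batch."""
    n = len(batch)
    grad_w = sum(2.0 * (w * x + b - y) * x for x, y in batch) / n
    grad_b = sum(2.0 * (w * x + b - y) for x, y in batch) / n
    return grad_w, grad_b

def train(data, batch_size, epochs=100, learning_rate=0.1, seed=0):
    rng = random.Random(seed)
    data = list(data)          # work on a copy so the caller's list is untouched
    w, b = 0.0, 0.0
    for _ in range(epochs):
        rng.shuffle(data)      # visit the examples in a new random order
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            grad_w, grad_b = minibatch_gradient(w, b, batch)
            w -= learning_rate * grad_w
            b -= learning_rate * grad_b
    return w, b

data = make_data()
sgd_w, sgd_b = train(data, batch_size=1)     # stochastic gradient descent
mb_w, mb_b = train(data, batch_size=20)      # mini-batch SGD
print(sgd_w, sgd_b)
print(mb_w, mb_b)
```

Because the synthetic data contains no noise, both runs should recover a slope near 2 and an intercept near 1; on real data the single-example updates would fluctuate more from step to step.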
There are two primary strategies for partitioning the training or inference process across the available machines or nodes: data parallelism and model parallelism.
In data parallelism, the data is divided into partitions, where the number of partitions is equal to the total number of available nodes (computational resources) in the cluster. The model is replicated on each of these worker nodes, and each worker operates on its own subset of the data. There are several approaches for synchronizing the model weights or parameters across the workers. The simplest of these is parameter averaging, whereby the global model parameters in the centralized parameter server are set to the average of the parameters from each worker after an iteration. The updated global parameters are then broadcast to all the workers for the subsequent round of training. An alternative is asynchronous stochastic gradient descent, in which the workers send their updates (gradients) to the parameter server rather than the parameters themselves.
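Parameter averaging can be simulated in a single process. The sketch below is an illustration under stated assumptions, not a distributed implementation: the workers run one after another, the model is a line fitted to synthetic data from y = 2x + 1, and all function names are hypothetical.

```python
import random

def make_data(n=400, seed=0):
    """Noise-free samples of the hypothetical target y = 2x + 1."""
    rng = random.Random(seed)
    return [(x, 2.0 * x + 1.0) for x in (rng.uniform(0.0, 2.0) for _ in range(n))]

def partition(data, num_workers):
    """Split the data into one partition per worker."""
    return [data[i::num_workers] for i in range(num_workers)]

def local_update(params, shard, learning_rate=0.1):
    """One pass of per-example SGD over a worker's own partition."""
    w, b = params
    for x, y in shard:
        err = w * x + b - y
        w -= learning_rate * 2.0 * err * x
        b -= learning_rate * 2.0 * err
    return w, b

def parameter_averaging(data, num_workers=4, rounds=50):
    shards = partition(data, num_workers)
    global_params = (0.0, 0.0)
    for _ in range(rounds):
        # Broadcast: every worker starts the round from the global parameters.
        local_params = [local_update(global_params, shard) for shard in shards]
        # Average: the server sets the global parameters to the workers' mean.
        global_params = tuple(
            sum(p[i] for p in local_params) / num_workers for i in range(2)
        )
    return global_params

avg_w, avg_b = parameter_averaging(make_data())
print(avg_w, avg_b)  # expected to be close to 2 and 1
```

In a real cluster the list comprehension over shards would run in parallel on separate machines, and the averaging step is where the synchronization cost appears.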
It is asynchronous in the sense that the model replicas run independently of each other. Google's deep learning framework DistBelief uses a variant of asynchronous stochastic gradient descent called Downpour SGD, which is designed to run on very large datasets. In this approach, before processing each mini-batch, a model replica asks the parameter server for an updated copy of its model parameters. After receiving the updated copy, the DistBelief model replica processes a mini-batch of data to compute a parameter gradient and sends the gradient to the parameter server, which then applies the gradient to the current value of the model parameters. Asynchronous SGD is more robust to machine failures than standard (synchronous) SGD because even if one model replica fails, the other model replicas can continue their training and keep sending their updates to the server.
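The fetch, compute, and push cycle described above can be sketched as a single-process simulation. This is not DistBelief code: the class names, the random scheduling that stands in for real asynchrony, the synthetic target y = 2x + 1, and the simulated failure halfway through are all assumptions made for illustration.

```python
import random

class ParameterServer:
    """Holds the global parameters and applies incoming gradients."""
    def __init__(self, params, learning_rate):
        self.params = list(params)
        self.learning_rate = learning_rate

    def fetch(self):
        return list(self.params)  # each replica receives its own copy

    def apply_gradient(self, grad):
        self.params = [p - self.learning_rate * g
                       for p, g in zip(self.params, grad)]

class ModelReplica:
    """Alternates between (fetch + compute gradient) and (push gradient)."""
    def __init__(self, shard, batch_size, rng):
        self.shard, self.batch_size, self.rng = shard, batch_size, rng
        self.pending = None

    def step(self, server):
        if self.pending is None:
            w, b = server.fetch()
            batch = self.rng.sample(self.shard, self.batch_size)
            errs = [(w * x + b - y, x) for x, y in batch]
            self.pending = [sum(2.0 * e * x for e, x in errs) / len(batch),
                            sum(2.0 * e for e, _ in errs) / len(batch)]
        else:
            # Other replicas may have updated the server since the fetch,
            # so this gradient can be stale when it is applied.
            server.apply_gradient(self.pending)
            self.pending = None

def train_async(num_replicas=4, steps=8000, seed=0):
    rng = random.Random(seed)
    data = [(x, 2.0 * x + 1.0) for x in (rng.uniform(0.0, 2.0) for _ in range(400))]
    shards = [data[i::num_replicas] for i in range(num_replicas)]
    server = ParameterServer([0.0, 0.0], learning_rate=0.02)
    replicas = [ModelReplica(shard, 10, rng) for shard in shards]
    for t in range(steps):
        if t == steps // 2:
            replicas.pop()  # simulate one replica failing; the rest carry on
        rng.choice(replicas).step(server)  # random order stands in for asynchrony
    return server.params

async_w, async_b = train_async()
print(async_w, async_b)  # expected to be close to 2 and 1
```

The small learning rate is a deliberate choice here: stale gradients behave like delayed feedback, and smaller steps make the simulation tolerant of that delay.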
In model parallelism, the computations corresponding to different parts of the neural network (for example, different neurons or layers) are carried out on different machines. A framework that supports model parallelism needs to automatically manage communication among the machines and synchronize the training or inference process. In the DistBelief framework, the centralized parameter server is sharded across many machines. For example, if there are 10 parameter server shards, each shard is responsible for storing and applying updates to 1/10th of the model parameters. During the SGD process, each machine communicates with just the subset of parameter server shards that hold the model parameters relevant to its partition.
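The sharding scheme can be sketched as follows. The source does not say how parameters are assigned to shards, so the modulo assignment used here is an assumption, as are the class and method names.

```python
class ShardedParameterServer:
    """Toy parameter server whose parameters are spread over several shards."""
    def __init__(self, num_params, num_shards, learning_rate=0.1):
        self.num_shards = num_shards
        self.learning_rate = learning_rate
        # Assumption: parameter i lives on shard i % num_shards.
        self.shards = [dict() for _ in range(num_shards)]
        for i in range(num_params):
            self.shards[i % num_shards][i] = 0.0

    def shard_of(self, index):
        return index % self.num_shards

    def shards_for(self, indices):
        """Shards a machine must contact for its partition of the model."""
        return sorted({self.shard_of(i) for i in indices})

    def fetch(self, indices):
        return {i: self.shards[self.shard_of(i)][i] for i in indices}

    def apply_gradients(self, grads):
        """Each shard applies updates only to the parameters it stores."""
        for i, g in grads.items():
            self.shards[self.shard_of(i)][i] -= self.learning_rate * g

server = ShardedParameterServer(num_params=100, num_shards=10)
partition_indices = [0, 10, 20, 3]   # hypothetical parameters used by one machine
needed = server.shards_for(partition_indices)
print(needed)  # [0, 3]
server.apply_gradients({3: 1.0})
print(server.fetch([3]))
```

With 100 parameters and 10 shards, each shard stores 10 parameters, and the machine in this example only needs to talk to two of the ten shards.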
Distributing model training in PyTorch
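The DataParallel wrapper class in the PyTorch package splits the input data across the available GPUs. The model is replicated on each device, and during the backward pass the gradients from each replica are summed into the original module. The code snippet below demonstrates the usage of this class (Model, input_size, and output_size are assumed to be defined elsewhere):

    import torch
    import torch.nn as nn

    # Use the first GPU if one is available, otherwise fall back to the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = Model(input_size, output_size)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # Wrap the model so each input batch is split across the GPUs.
        model = nn.DataParallel(model)
    model.to(device)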
Although distributed training of deep learning models helps to scale up the network, it comes with the overhead of synchronization and network transfer of parameters. It is worth using when the gains in computation outweigh these overheads. The partitioning strategy and approach to be followed depend on a number of factors such as the size of the network and the cluster hardware. The correct approach can be selected based on the training speed and accuracy achieved.

References:
Dean, Jeffrey, et al. "Large scale distributed deep networks." Advances in Neural Information Processing Systems. 2012.
https://developers.google.com/machine-learning/crash-course/reducing-loss/stochastic-gradient-descent
https://blog.skymind.ai/distributed-deep-learning-part-1-an-introduction-to-distributed-training-of-neural-networks/
Ben-Nun, Tal, and Torsten Hoefler. "Demystifying Parallel and Distributed Deep Learning: An In-Depth Concurrency Analysis." arXiv preprint arXiv:1802.09941 (2018).
https://pytorch.org/docs/stable/_modules/torch/nn/parallel/data_parallel.html
https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html