The annual HOT CHIPS conference took place on August 17-18. Of course, it was virtual. As always, on the Sunday before there were two half-day tutorials. In the morning, it was on scaling deep learning training. In this context, "scaling" means running the training on a large number of machines. So in that sense, it is very similar to when EDA software "scales" to run in the cloud. One difference is that a lot of work has gone into building specialized hardware for deep learning training, and later in the morning we heard about several of the most advanced AI training "pods" around. I suppose if you squint just right, you could say that the Palladium and Protium platforms are specialized "pods" for scaling simulation.
The morning started (at 8:30am on a Sunday morning...no lie-in for me) with Paulius Micikevicius of NVIDIA, who gave a presentation, Fundamentals of Scaling Out DL Training. Compared to some of the presentations that came later in the morning, this was introductory, although it did assume a working knowledge of matrix algebra.
If you don't follow deep learning models on a regular basis, it is easy to miss that they are growing in size by an order of magnitude or more every year. For example, Paulius pointed out:
Although Paulius didn't mention it, the recently announced GPT-3 language model has 175B parameters and is supposedly the largest model ever produced. If you want to read more about it (recommended), the MIT Technology Review has an article OpenAI’s new language generator GPT-3 is shockingly good—and completely mindless. It would take 355 years and cost $4.6M to train naively on the lowest-priced GPU cloud server on the market.
Larger input datasets lead to higher accuracy, too. Recommender data (user behavior on Amazon, Netflix, or YouTube) has gone from terabytes to petabytes. Image data has gone from the few million images of ImageNet to the 1B-image Instagram dataset.
So, in a problem that seems familiar in EDA, larger models no longer fit in a single processor and take too long to train. The solution is scale-out computing, using lots of computers (perhaps specialized ones) instead of trying to do everything on a single core.
Training a neural network works like this. Start by initializing all the weights to random values. Then take a minibatch of some of the training data (think images of cats and not-cats if you want something concrete). Do a forward pass through the current state of the neural network. Assess the errors. Do a backward pass. Update the weights. Move onto the next minibatch and repeat.
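The loop above can be sketched in a few lines of numpy. Everything here is illustrative and invented for the sketch (layer sizes, learning rate, and a toy "sum of inputs positive?" task standing in for cats and not-cats), but the shape of the loop is exactly the one described: minibatch, forward pass, assess errors, backward pass, update weights, repeat.

```python
import numpy as np

rng = np.random.default_rng(0)

# A tiny two-layer network trained on a toy binary task:
# predict whether the elements of the input sum to a positive number.
n_in, n_hidden = 4, 8
W1 = rng.standard_normal((n_hidden, n_in)) * 0.1  # start with random weights
W2 = rng.standard_normal((1, n_hidden)) * 0.1

X_all = rng.standard_normal((64, n_in))           # toy "training data"
y_all = (X_all.sum(axis=1, keepdims=True) > 0).astype(float)

lr, batch_size, losses = 0.1, 8, []
for step in range(200):
    # Take a minibatch of the training data
    i = (step * batch_size) % len(X_all)
    X, y = X_all[i:i + batch_size], y_all[i:i + batch_size]
    # Forward pass through the current state of the network
    h = np.maximum(X @ W1.T, 0.0)                 # hidden layer with ReLU
    p = 1.0 / (1.0 + np.exp(-(h @ W2.T)))         # sigmoid output
    # Assess the errors (p - y is the cross-entropy gradient at the logits)
    err = p - y
    losses.append(float(np.mean(err ** 2)))       # track a simple error measure
    # Backward pass
    dW2 = err.T @ h / batch_size
    dh = (err @ W2) * (h > 0)
    dW1 = dh.T @ X / batch_size
    # Update the weights, then move on to the next minibatch
    W2 -= lr * dW2
    W1 -= lr * dW1
```

Run it and the tracked error falls as the weights move away from their random starting point, which is all "training" means at this level of abstraction.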
Let's look at a simplified example, a network of three linear layers. For each layer, the input and the output are both vectors. The operation consists of multiplying the input vector by the layer's matrix of weights, and then applying a point-wise nonlinearity such as ReLU (don't worry too much about that bit).
So we start with the forward pass. The input data is loaded (top left) and the first (yellow) matrix multiplication is done (top right) to get an intermediate vector. That intermediate vector is multiplied by the second (purple) layer's matrix to get another intermediate vector (lower left), and finally that vector is multiplied by the final (blue) layer's matrix to get the output.
In practice, we wouldn't process just one vector at a time; we'd group them into minibatches (the size depends on the compute resources). The vector-matrix multiplies turn into matrix-matrix multiplies, as in the diagram above, which uses a batch size of two.
Note also that all the intermediate vectors/matrices need to be retained, since they will be required again during the backward pass.
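Here is a minimal numpy sketch of that forward pass with a batch size of two (the layer dimensions are made up for illustration). Note how the batch dimension rides along through every multiply, and how the intermediates a1 and a2 are kept in variables rather than discarded:

```python
import numpy as np

rng = np.random.default_rng(42)

def relu(x):
    return np.maximum(x, 0.0)

# Three linear layers ("yellow", "purple", "blue"); sizes are illustrative.
W1 = rng.standard_normal((16, 8))    # first layer:  8 -> 16
W2 = rng.standard_normal((16, 16))   # second layer: 16 -> 16
W3 = rng.standard_normal((4, 16))    # third layer:  16 -> 4

X = rng.standard_normal((2, 8))      # minibatch of two input vectors

# Forward pass: each vector-matrix multiply becomes a matrix-matrix
# multiply because of the batch dimension. The intermediates (a1, a2)
# must be retained for the backward pass.
a1 = relu(X @ W1.T)                  # shape (2, 16)
a2 = relu(a1 @ W2.T)                 # shape (2, 16)
out = a2 @ W3.T                      # shape (2, 4)
```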
The next step is to calculate the loss function, which is a measure of how wrong the network was. Note that this requires what is known as labeled data, since we need to know the correct answer in each case so that we can compare it to what the neural network came up with. A dataset used in a lot of neural network university courses has thousands of handwritten digits (0-9), and the labels are the correct answers. If we create a neural network to process the handwritten images, then at the output we have a probability for each digit. Early in training, these will be all over the place (remember, we started with random weights), and later the network should show a huge probability for the right answer. Partway through training, you expect the right answer to have the highest probability, but other digits still show up as possibilities. Those values are how "wrong" the network was in assessing the digit. The goal of training is to minimize the loss value; that is, to update all the weights so that the output is closer to the correct answer.
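For a classification problem like the digits, the usual loss is cross-entropy: the negative log of the probability the network assigned to the correct class. A small sketch, with the two probability vectors invented to represent "early" and "late" in training for a handwritten 7:

```python
import numpy as np

def cross_entropy(probs, label):
    # The loss is the negative log-probability assigned to the correct class.
    return -np.log(probs[label])

# Hypothetical softmax outputs for a handwritten "7".
early = np.full(10, 0.1)   # uniform: the random network knows nothing
late = np.array([.01, .01, .01, .01, .01, .01, .01, .91, .01, .01])

loss_early = cross_entropy(early, label=7)   # -ln(0.1)  ~ 2.30
loss_late = cross_entropy(late, label=7)     # -ln(0.91) ~ 0.094
```

The loss shrinks as the probability mass piles up on the right digit, which is exactly the behavior training is trying to produce.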
The next stage is to update the weights so that next time the loss function should be smaller. This is done in two phases. First comes the backward pass, which "back propagates" the loss through the layers. Each layer computes a weight gradient (used to update its own weights) and an activation gradient (used to backpropagate to the preceding layer). In the diagram above, dY is the incoming activation gradient (from the following layer) and X is the input activations (saved from the forward pass, which is why they need to be kept around). From these we calculate dW, the weight gradient that will be used to update the weights. Using the current weights W, we calculate dX to backpropagate to the preceding layer (where it becomes the dY of this same calculation for that layer).
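For a linear layer computed as Y = X·Wᵀ (batch along the rows, the convention used in the forward-pass sketch earlier), both gradients are themselves just matrix multiplies. The shapes here are invented for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)

# One linear layer: Y = X @ W.T, batch of 2, 8 inputs, 4 outputs.
X = rng.standard_normal((2, 8))    # input activations, saved from the forward pass
W = rng.standard_normal((4, 8))    # current weights
dY = rng.standard_normal((2, 4))   # incoming activation gradient

dW = dY.T @ X    # weight gradient: used to update this layer's weights
dX = dY @ W      # activation gradient: becomes the preceding layer's dY
```

Both outputs mirror the shapes of the things they are gradients of: dW matches W, and dX matches X, ready to be handed back to the previous layer.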
Next, the weights are updated. This is actually where the smarts are in training; it is also known as the optimizer step. There are lots of approaches, such as SGD, Adam, Adagrad, and more. These take as input the current network weights along with the weight gradients (the dW in the above calculation) and update the weights. Although it is possible to update the weights directly from the weight gradient, in practice a lot of internal state is also retained, so the actual update is more like "update the state with the weight gradient, and then use the updated state to update the actual weights".
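As one concrete (and common) example, SGD with momentum keeps exactly one extra state tensor per weight tensor, and its update follows the "update the state, then use the state" pattern verbatim. The learning rate and momentum values here are just typical defaults, not anything from the talk:

```python
import numpy as np

def sgd_momentum_step(W, dW, velocity, lr=0.01, momentum=0.9):
    """One optimizer step. `velocity` is the internal state that is
    retained between steps (same shape as the weights)."""
    velocity[:] = momentum * velocity + dW   # update the state with dW...
    W[:] = W - lr * velocity                 # ...then use the state to update W
    return W, velocity

W = np.ones((4, 8))
v = np.zeros_like(W)          # internal state starts at zero
dW = np.full((4, 8), 0.5)     # a pretend gradient from the backward pass
W, v = sgd_momentum_step(W, dW, v)
```

Optimizers like Adam keep two such state tensors per weight tensor, which is where the memory multiplier discussed next comes from.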
The internal state is significant since each weight might carry one or two momentum terms, and even if training is done in 16-bit precision, the momenta may need to be kept in 32-bit. So in practice, the internal state may occupy 2-6X as much memory as the model itself.
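The high end of that 2-6X range is easy to reconstruct with back-of-the-envelope arithmetic. Assuming (my assumption, but a common mixed-precision recipe) an Adam-style optimizer holding two FP32 momenta plus an FP32 master copy for every FP16 model weight:

```python
# Bytes per parameter under a common mixed-precision training recipe
# (assumed here for illustration, not taken from the talk).
weight_bytes = 2              # the FP16 model weight itself
state_bytes = 4 + 4 + 4       # two FP32 momenta + an FP32 master weight

ratio = state_bytes / weight_bytes   # optimizer state vs. the model itself
```

With plain momentum SGD (one FP32 state per FP16 weight) the ratio drops to 2X, bracketing the range Paulius quoted.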
So that's an (over) simplified view of how training is done. Some things to note:
As is usually the case when trying to run an algorithm in parallel on lots of servers, the challenge is:
To scale training, there are really three approaches: data-parallel (each worker holds the whole model and works on a slice of the data), inter-layer model-parallel (each worker holds some of the layers, pipeline style), and intra-layer model-parallel (each worker holds part of every layer).
I'm not going to go through all of these in detail. Let's look at data-parallel in a little detail, and then I'll just describe the others at a high level.
For data-parallel, each worker has a copy of the entire neural network model and is responsible for computing a portion of the data (each worker gets a different minibatch). The forward and backward passes are just as I described above. The difference comes during the weight update: before it can take place, the gradient contributions from all the workers need to be summed, and then each worker updates its own copy of the model with the summed gradients. So, from a communication point of view, nothing is needed during the forward and backward passes; then all the workers need to transmit their contribution to the weight update and receive the updated weights.
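One data-parallel update can be sketched with the all-reduce faked in plain numpy (in a real system this would be NCCL or MPI traffic over an actual network, and the gradients would come from real backward passes rather than random numbers):

```python
import numpy as np

rng = np.random.default_rng(7)
n_workers, lr = 4, 0.1

# Every worker holds an identical replica of the model
# (represented here by a single weight matrix).
W = rng.standard_normal((4, 8))
replicas = [W.copy() for _ in range(n_workers)]

# Each worker computes gradients on its own minibatch (faked here).
local_grads = [rng.standard_normal((4, 8)) for _ in range(n_workers)]

# All-reduce: combine the gradient contributions from every worker...
avg_grad = sum(local_grads) / n_workers

# ...then every worker applies the same update to its own replica.
for i in range(n_workers):
    replicas[i] -= lr * avg_grad
```

Because every replica applies an identical update, the copies stay in lockstep, so the weights themselves never need to be shipped between workers, only the gradients.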
There are several subtly different ways to do this, and which is best will depend on the network topology:
Note also that communication can be overlapped with a lot of the computation: all the communication associated with layer K can be overlapped with computing the gradients for layer K-1.
Data-parallel scaling can be done in one of two ways. Strong scaling means add more workers but keep the mini-batch size the same. The problem with this is that for a small mini-batch size and a large model, there is a lot of communication to update the weights after each mini-batch. Weak scaling means to both add more workers and increase the size of the mini-batch (so that communication doesn't need to happen so often). This typically requires some adjustment to how the training is done and, perhaps, the weight update method used. Typically, the amount of work required to reach a given accuracy will increase, but that should be offset by having more workers.
As I said above, there are two forms of model-parallel: inter-layer (pipelined), where each worker is responsible for one or more complete layers of the model, and intra-layer, where each worker is responsible for a portion of every layer (top, middle, and bottom in the picture above, although obviously in the real world there will be more than three workers). The diagram above shows the basic idea.
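The inter-layer split is easy to picture in code: each worker owns one layer's weights, and what travels between workers is the activations at every hand-off. A toy sketch (sizes invented, and the "workers" are just function calls standing in for separate machines):

```python
import numpy as np

rng = np.random.default_rng(3)

def relu(x):
    return np.maximum(x, 0.0)

# Three layers, one per "worker" (an illustrative inter-layer split).
layers = [rng.standard_normal((16, 8)),    #  8 -> 16, worker 0
          rng.standard_normal((16, 16)),   # 16 -> 16, worker 1
          rng.standard_normal((4, 16))]    # 16 ->  4, worker 2

def worker_forward(worker_id, activations):
    # Each worker holds only its own layer's weights; the activations
    # passed in and returned are what would cross the network.
    return relu(activations @ layers[worker_id].T)

x = rng.standard_normal((2, 8))            # minibatch of two
for wid in range(3):
    x = worker_forward(wid, x)             # each hand-off = communication
```

Every loop iteration here is an inter-worker transfer, which is why this scheme communicates at every stage rather than once per minibatch like data-parallel.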
The challenge of the pipelined inter-layer approach is that there is communication at each stage of the forward and backward passes (unlike the data-parallel approach), and it is very hard to overlap communication with computation. It is also difficult to load-balance across the workers, since different layers may take different amounts of time to compute, meaning that some workers sit idle waiting for the busy ones to complete.
There are various tricks in intra-layer to reduce communication, such as alternating between horizontal and vertical partitions, but this approach also suffers from the difficulty of overlapping communication with computation. Still, when the model is too big for data-parallelism, you have to deal with it: not being able to fit the model in a single server is a hard limit for data-parallel.
Data-parallel is the simplest way to scale out. But it requires each node to hold the entire model (all the weights) and that is a hard limit. So despite all the communications complexity, for the largest models, some form of model-parallelism is the only way to scale to a bazillion cores.
I'll write about the second part of the tutorial another time: NVIDIA's top-10 supercomputer, Google's TPU Pods, Cerebras's wafer-scale chip, and more.
Sign up for Sunday Brunch, the weekly Breakfast Bytes email.