GSPMD Style: Pipeline Parallelism For Distributed LLM Training
Overview of Pipeline Parallelism, GPipe, GSPMD-style Pipeline Parallelism, and its application to LLM training
In this analysis, we look into a parallelism strategy widely used in LLM training called Pipeline Parallelism. Furthermore, we look into different pipeline parallelism schedules that “maximize intelligence per TCO” by reducing worker idleness, also known as “the bubble”. In JAX, a popular ML framework, pipeline parallelism is written in an SPMD style that is extremely confusing. With the help of animations, we will break down line by line how JAX SPMD pipeline parallelism works. Lastly, we will briefly chat about why pipeline parallelism is used on GPUs but not TPUs.
Pipeline Parallelism
When you can’t fit all the parameters and optimizer state on each GPU, you are unable to use traditional distributed data parallelism, so you need to split the layers across different GPUs. If before you had 100 GBytes of parameters and optimizer state, then after applying a 4-stage pipeline parallelism strategy across 4 GPUs, each GPU would only need 25 GBytes of GPU memory to store its share of the parameters and optimizer state.
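As a toy illustration (this is my own sketch, not the post’s code), here is how a 16-layer model could be carved into 4 stages so that each GPU only ever holds its own slice of the weights:

```python
import jax
import jax.numpy as jnp

NUM_LAYERS = 16
NUM_STAGES = 4
LAYERS_PER_STAGE = NUM_LAYERS // NUM_STAGES
FEATURE_DIM = 8

def init_stage_params(key, stage_id):
    # In a real setup these weights (and their optimizer state) live only in
    # that stage's GPU memory, so each of the 4 GPUs holds 1/4 of the total.
    stage_key = jax.random.fold_in(key, stage_id)
    keys = jax.random.split(stage_key, LAYERS_PER_STAGE)
    return [jax.random.normal(k, (FEATURE_DIM, FEATURE_DIM)) * 0.1 for k in keys]

def stage_forward(stage_params, x):
    # Runs only this stage's slice of layers; the output activation is the
    # only thing that has to cross the network to the next stage's GPU.
    for w in stage_params:
        x = jnp.tanh(x @ w)
    return x
```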
Now that we have talked about the advantage of pipeline parallelism, the major drawback is that, as you can see in the diagram above, there is plenty of idle time. In the industry, we call that “the bubble”. When you have idle time, your tokens per TCO goes down and your “intelligence per picojoule” goes down too: your GPUs are sitting idle most of the time, but you still need fans to cool them, and it takes a while (on the order of tens of ms) for GPUs to throttle down their power usage when they are not running kernels.
There are plenty of different pipeline schedules designed to minimize the bubble, but each has trade-offs. For this blog post, we will talk about the most widely used one, called GPipe.
GPipe
GPipe solves this by splitting a single batch into multiple microbatches. With enough microbatches, the "bubble" is very small. Within a batch, each microbatch uses the same weights, and gradients are accumulated until all the microbatches have done a forward and backward pass through all the stages. Once all the microbatches in that batch have finished their fwd/bwd pass, each stage applies the gradients in parallel (with an optimizer such as Adam, SGD, etc.); see the sketch after the animation. Then this process repeats for the next batch. I would recommend playing the video and pausing as each worker independently processes a microbatch.
The animation above is an original animation.
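To make the accumulate-then-update part concrete, here is a minimal sketch (my own names, not from any particular library) of how each stage’s weights are only touched once per batch. It deliberately ignores the stage-to-stage activation and gradient traffic that the animation shows, and assumes a toy `loss_fn(params, x, y)` helper and plain SGD:

```python
import jax
import jax.numpy as jnp

def gpipe_batch_step(stage_params, microbatches, targets, loss_fn, lr=1e-3):
    """Accumulate gradients over all microbatches, then apply one update.

    stage_params: this stage's weights (a pytree).
    microbatches/targets: the batch split into equal microbatches along axis 0.
    loss_fn(params, x, y) -> scalar loss for one microbatch (assumed helper).
    """
    grad_fn = jax.grad(loss_fn)
    # Every microbatch within the batch sees the same weights.
    accumulated = jax.tree_util.tree_map(jnp.zeros_like, stage_params)
    for x, y in zip(microbatches, targets):
        grads = grad_fn(stage_params, x, y)
        accumulated = jax.tree_util.tree_map(jnp.add, accumulated, grads)
    # Only after every microbatch has finished its fwd/bwd pass does the
    # stage apply the update (averaged here; Adam etc. works the same way).
    num_micro = len(microbatches)
    return jax.tree_util.tree_map(
        lambda p, g: p - lr * g / num_micro, stage_params, accumulated)
```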
As you can see, the bubble between the forward and backward pass gets way smaller, but we still have the problem that the startup and cooldown of each batch still have idle time. This is fine as long as there are enough microbatches to amortize the startup and cooldown period.
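As a rough rule of thumb (assuming forward and backward passes take similar time per microbatch per stage), with S stages and M microbatches the idle fraction of a GPipe step is about (S - 1) / (M + S - 1). With 4 stages and 4 microbatches that is roughly 43% of the step wasted, but with 32 microbatches it drops to around 9%, which is why you want M to be much larger than S.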
SPMD Style GPipe
In JAX, you can only program in a single program multiple data (SPMD) style. This dense piece of code confuses everyone from dropouts to PhDs. A couple of years ago, I spent days trying to understand exactly how it works.
In order to vectorize a OneStageCompute, you use jax.vmap to vectorize it across the number of stages. Next, you need a shifting state buffer to pass the output of one stage to the input of the next stage (line 1). The state needs to be initialized with zeros, and you also need to pad the end of your inputs by (number of stages - 1), so that all of the inputs can make it through the whole pipeline. Note that since you vectorize over OneStageCompute and it is an SPMD programming model, all workers run the same computation, just with different data and weights. The input buffer is the same as the state except for the first element, which is the next_input (look at the elementwise_select line). Also note that the “input buffer” is probably optimized away by the compiler, but for illustration purposes I have included it in the visualization.
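To make this concrete, here is a self-contained sketch of the same idea. The toy per-stage math, shapes, and helper names (shift_and_select, spmd_gpipe_forward) are mine for illustration, not the original code’s identifiers; only the structure, the vmap over stages, the zero-initialized shifting state buffer, the padded input stream, and the select for the first stage, mirrors what is described above:

```python
import jax
import jax.numpy as jnp

NUM_STAGES = 4
NUM_MICROBATCHES = 8
FEATURE_DIM = 16

def one_stage_compute(stage_weights, x):
    # Toy per-stage computation; a real stage would be a transformer block.
    return jnp.tanh(x @ stage_weights)

# vmap vectorizes the single-stage function across the stage axis: every
# stage runs the same program on its own weights and its own slot of the
# shifting state buffer (the SPMD part).
all_stages_compute = jax.vmap(one_stage_compute, in_axes=(0, 0))

def shift_and_select(state, next_input):
    # Shift the state buffer by one so that stage i sees stage i-1's output...
    shifted = jnp.roll(state, shift=1, axis=0)
    # ...except the first stage, which reads the next microbatch instead
    # (the role of the elementwise_select described above).
    is_first_stage = (jnp.arange(NUM_STAGES) == 0)[:, None, None]
    return jnp.where(is_first_stage, next_input[None], shifted)

def spmd_gpipe_forward(weights, microbatches):
    # weights: (stages, feature, feature); microbatches: (micro, batch, feature)
    # Pad the input stream with (num_stages - 1) zero microbatches so the
    # last real microbatch can flush through every stage.
    pad = jnp.zeros((NUM_STAGES - 1,) + microbatches.shape[1:])
    inputs = jnp.concatenate([microbatches, pad], axis=0)

    # The state buffer starts as zeros and holds one activation per stage.
    state = jnp.zeros((NUM_STAGES,) + microbatches.shape[1:])
    outputs = []
    for t in range(inputs.shape[0]):
        state = shift_and_select(state, inputs[t])
        state = all_stages_compute(weights, state)
        # Once the pipeline is full, the last slot holds a finished microbatch.
        if t >= NUM_STAGES - 1:
            outputs.append(state[-1])
    return jnp.stack(outputs)

key = jax.random.PRNGKey(0)
weights = jax.random.normal(key, (NUM_STAGES, FEATURE_DIM, FEATURE_DIM)) * 0.1
microbatches = jax.random.normal(key, (NUM_MICROBATCHES, 2, FEATURE_DIM))
print(spmd_gpipe_forward(weights, microbatches).shape)  # (8, 2, 16)
```

On a real device mesh the shift would be a cross-device collective (e.g. jax.lax.ppermute) rather than a jnp.roll, and the Python loop would be a lax.scan, which is a big part of why the production version reads so densely.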
I have created the animation below. I believe the best way to understand the SPMD pipeline is to read the code and play/pause the animation side by side. It took me a couple of days to understand exactly how it works from first principles.
Implications on TPUs and GPUs
The reason why pipeline parallelism is used on GPUs but not TPUs is that the current generation of GPUs (H100, MI300X) only have 400 Gbit/s of networking to the thousands of other chips. With Google’s TPUv5p, you have 4,800 Gbit/s of ICI per chip in a 3D torus topology, so you only need fully sharded data parallelism and tensor parallelism to fit your model without communication bottlenecks. One of the main advantages of TPUs is that pipeline parallelism is not required to scale to tens of thousands of chips. Generally, ML systems engineers hate reasoning about and debugging pipeline parallelism. It is way easier to just change the mesh size to scale (as on TPUs) than to keep adding more parallelism strategies when scaling on GPUs.
Furthermore, chips without enough HBM, such as Groq and Tenstorrent chips, need to use pipeline parallelism too.
We will dive into the other pipeline parallelism schedules in later articles, as animating these schedules takes a lot of time.