Am I misunderstanding Zero-1 and Zero-2?

#94
by Guanghua

From the steps and the figures describing ZeRO-1, it uses reduce-scatter for gradient aggregation instead of all-reduce. Isn't that the same as ZeRO-2? What is the real difference between the two?

I feel that the reduce-scatter in ZeRO-1 should be an all-reduce, but I could be missing something.

Hey, I think reduce-scatter is the correct operation in ZeRO-1: once the optimizer states are partitioned, each GPU only needs the gradient shard that corresponds to its optimizer-state partition. The difference between all-reduce and reduce-scatter is that with all-reduce every GPU ends up with the complete reduced data, while with reduce-scatter each GPU ends up with only its reduced shard. Since in both ZeRO-1 and ZeRO-2 each GPU only needs its own shard, reduce-scatter is the appropriate operation. All-reduce is only needed in naive DDP (Distributed Data Parallel), where every GPU requires the complete reduced gradients to perform the full optimizer step.
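
Since the distinction keeps coming up, here is a minimal, self-contained sketch (my own illustration, not code from the blog post; the world size, tensor size, and variable names are all made up) that simulates what each rank ends up holding after an all-reduce versus a reduce-scatter:

```python
import torch

world_size = 4
numel = 8  # pretend the model has 8 gradient elements in total

# Each "rank" computed its own local gradients on its own micro-batch.
local_grads = [torch.randn(numel) for _ in range(world_size)]
summed = torch.stack(local_grads).sum(dim=0)  # the reduction both collectives compute

# All-reduce (naive DDP): every rank ends up holding the FULL summed gradient.
after_all_reduce = [summed.clone() for _ in range(world_size)]

# Reduce-scatter (ZeRO-1 / ZeRO-2): rank i only receives summed shard i.
shards = summed.chunk(world_size)
after_reduce_scatter = [shards[i].clone() for i in range(world_size)]

for i in range(world_size):
    print(
        f"rank {i}: all_reduce keeps {after_all_reduce[i].numel()} elements, "
        f"reduce_scatter keeps {after_reduce_scatter[i].numel()} elements"
    )
# Each rank only needs its own shard to update its slice of the optimizer states,
# which is why reduce-scatter is enough for both ZeRO-1 and ZeRO-2.
```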

Hi, I also see some contradictory information in the blog post regarding the difference between ZeRO-1 and ZeRO-2. The blog post mentions that 'Zero-1 change our "all-reduce" gradient communication to a "reduce-scatter" operation' (and the GIF also shows a reduce-scatter operation), but then proceeds to say for ZeRO-2: 'During the backward pass, instead of performing an all-reduce over the gradients, we only perform a reduce-scatter operation!'
I would appreciate some clarification on this.

I'm pretty sure that both ZeRO-1 and ZeRO-2 use reduce-scatter, so the ZeRO-1 part—"ZeRO-1 changes the 'all-reduce' gradient communication to a 'reduce-scatter' operation"—is correct.

As for the ZeRO-2 part—"During the backward pass, instead of performing an all-reduce over the gradients, we only perform a reduce-scatter operation!"—the phrase "instead of" can be ambiguous, as it's unclear whether it's comparing to the naive approach or to ZeRO-1.

A clearer version might be:

"During the backward pass, since each rank only needs partitioned gradients for the update, we also perform a reduce-scatter operation, just like in ZeRO-1. However, in ZeRO-2, each rank only needs to store 1/Nd of the gradients, which leads to even more memory savings compared to ZeRO-1."

@DandinPower

I think what @Guanghua means is that the ZeRO-2 animation is incorrect. Under the ZeRO-2 strategy, because data parallelism still exists, if we want to keep only one copy of each gradient shard on one card and delete the gradient copies on the other cards, the replicas first need to be reduced so that they are synchronized into that single remaining copy before deletion. However, the ZeRO-2 animation in the text actually shows the ZeRO-1 situation (the figure never deletes the gradient copies on the other cards).
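
For what it's worth, the reduce-scatter collective fuses exactly that "reduce first, then keep only one copy" pattern. Here is a quick toy check (my own sketch, made-up sizes) that a per-shard reduce onto the owning rank followed by deleting the other replicas gives the same result as a single reduce-scatter:

```python
import torch

world_size = 4
numel = 8
local_grads = [torch.randn(numel) for _ in range(world_size)]

# Option A: for each shard, explicitly reduce the replicas onto the owning rank,
# then drop the copies held by every other rank.
reduced_on_owner = []
for owner in range(world_size):
    replica_slices = [g.chunk(world_size)[owner] for g in local_grads]
    reduced_on_owner.append(torch.stack(replica_slices).sum(dim=0))

# Option B: a single reduce-scatter over the full gradient.
summed = torch.stack(local_grads).sum(dim=0)
scattered = list(summed.chunk(world_size))

for owner in range(world_size):
    assert torch.allclose(reduced_on_owner[owner], scattered[owner])
print("reduce-scatter == per-shard reduce onto the owner + dropping the other replicas")
```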

By the way, I'm quite curious about how the efficiency of methods that shard activation values compares to using gradient checkpointing. If anyone knows about this, please @ me. Thank you very much!
