Fix ZeRO sort to be by numel (#60556)
Summary:
**Overview:**
This is a follow-up to [this PR](https://github.com/pytorch/pytorch/pull/59586) and corrects the ZeRO partitioning algorithm to sort by the number of elements in the tensor rather than the size of the first dimension. As context, that PR was meant to migrate from using a _naive greedy_ algorithm to a _sorted-greedy_ algorithm when partitioning parameters in ZeRO.
**Updated Results:**
The updated table for the partitions can be found [here](https://github.com/pytorch/pytorch/pull/59410#issuecomment-865203219). There, I also considered a third algorithm (sometimes known as multifit), which is more computationally expensive than the greedy and sorted-greedy algorithms but cannot perform worse. However, because of its increased complexity and lack of improved results, I chose to settle with the simpler sorted-greedy algorithm.
The `step()` latencies show slight improvements, but the improvements may be in the noise. The values below are in seconds and were generated using NCCL backend (unlike in the previous PR which used Gloo):
Two processes:
| Model | Max `optimizer.step()` Time - Greedy (Std.) | Max `optimizer.step()` Time - Sorted-Greedy (Std.) |
| --- | --- | --- |
| ResNet-50 | 0.047 (0.00142) | **0.044 (0.00025)** |
| ResNet-152 | 0.057 (0.00034) | **0.054 (0.00022)** |
| BERT | 0.021 (0.00008) | **0.020 (0.00008)** |
Four processes:
| Model | Max `optimizer.step()` Time - Greedy | Max `optimizer.step()` Time - Sorted-Greedy (Std.) |
| --- | --- | --- |
| ResNet-50 | 0.019 (0.00065) | **0.013 (0.00040)** |
| ResNet-152 | 0.045 (0.00024) | 0.045 (0.00025) |
| BERT | 0.019 (0.00022) | **0.018 (0.00016)** |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60556
Test Plan:
I verified that the ZeRO tests pass (via the AI AWS cluster):
```
srun -p $DEV_QUEUE --cpus-per-task=16 -t 5:00:00 --gpus-per-node=4 python test/distributed/optim/test_zero_redundancy_optimizer.py
```
Reviewed By: VitalyFedyunin
Differential Revision: D29335260
Pulled By: andwgu
fbshipit-source-id: 469d1c6e029b77c1b300a94cd1fd94b633cd28dd