87088023 - Enables configuration of NCCL communicators (#97394)

Enables configuration of NCCL communicators (#97394)

NCCL 2.17+ introduces user-configurable parameters for NCCL communicators via the [ncclConfig_t](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#c.ncclConfig_t) datatype and [ncclCommInitRankConfig](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcomminitrankconfig). This PR enables that feature. A user can tune the parameters as follows:

```
import torch.distributed as dist

nccl_options = dist.ProcessGroupNCCL.Options()
nccl_options.config.max_ctas = 32
nccl_options.config.min_ctas = 8
nccl_options.config.cga_cluster_size = 2

dist.init_process_group(backend='nccl', init_method='env://', pg_options=nccl_options)
my_group = dist.new_group(pg_options=nccl_options)
```

The default values of these parameters are those set by `NCCL_CONFIG_INITIALIZER`. For DistributedDataParallel only, this PR sets the default value of `cga_cluster_size` to 2, a heuristic that works especially well for DDP workloads. Tuning these parameters can improve end-to-end performance, since they affect the communication-computation overlap of NCCL kernels.

CC: @ptrblck @kwen2501

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97394
Approved by: https://github.com/kwen2501
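
For context, a minimal sketch of how these options might feed into a DistributedDataParallel run is shown below. It is not part of the PR: the helper name `init_with_nccl_config`, the toy `torch.nn.Linear` model, and the version gate via `torch.cuda.nccl.version()` are illustrative assumptions, and the usual rendezvous environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) are assumed to be set by the launcher.

```
# Illustrative sketch only: assumes a CUDA build of PyTorch with NCCL and
# torchrun-style environment variables already set.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def init_with_nccl_config():
    # Only apply the new config knobs when the linked NCCL is new enough.
    if torch.cuda.nccl.version() >= (2, 17):
        nccl_options = dist.ProcessGroupNCCL.Options()
        nccl_options.config.max_ctas = 32
        nccl_options.config.min_ctas = 8
        nccl_options.config.cga_cluster_size = 2
    else:
        # Fall back to the defaults from NCCL_CONFIG_INITIALIZER.
        nccl_options = None

    dist.init_process_group(backend='nccl', init_method='env://',
                            pg_options=nccl_options)


def main():
    init_with_nccl_config()
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # Toy model; gradient all-reduce runs on the configured NCCL communicator.
    model = torch.nn.Linear(1024, 1024).cuda(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    inputs = torch.randn(32, 1024, device=local_rank)
    loss = ddp_model(inputs).sum()
    loss.backward()
    optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```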