pytorch
44aa4ad8 - Use `_all_gather_base` and fuse matmul for sharded linear.

Use `_all_gather_base` instead of `all_gather` for col-wise sharding, since `_all_gather_base` returns a single fused tensor that can be used to perform a single matmul instead of looping through and performing multiple matmuls. This improves performance for col-wise sharding.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78477
Differential Revision: [D36754385](https://our.internmc.facebook.com/intern/diff/D36754385/)
Approved by: https://github.com/aazzolini, https://github.com/wanchaol