7c44d560 - [PT-D][Sharding] Enable ops needed in the transformer model training (#75374)

[PT-D][Sharding] Enable ops needed in the transformer model training (#75374)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/75374

While working with the FairSeq and MetaSeq code bases (both essentially transformer models), we found that many ops are not supported by sharded tensors. We now implement a simple version of each so that we can at least run a transformer example. The ops include: chunk, transpose, view, masked_fill, dropout, softmax, and type_as. The common logic of registering a simple op is isolated into a single function, so registering a new op in the future requires implementing at most three functions.

ghstack-source-id: 155309147
Test Plan: CI
Reviewed By: pritamdamania87
Differential Revision: D35123021
fbshipit-source-id: 660e559fb8b4a910eb63e0586c63ab927873a2ce
(cherry picked from commit 83a87ebf627d863448dfe1019c7c5f7112cc14ab)
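The registration pattern the commit describes can be sketched in plain Python. This is a hypothetical, torch-free illustration of the idea, not the actual PyTorch API: `SimpleShardedTensor`, `register_simple_op`, and the two-callback shape are invented names for this sketch (the real helper accepts up to three functions per op); ops that only transform each local shard independently all go through one shared registration helper.

```python
import math

# Registry mapping op name -> (validate, customized_func).
# Both callbacks are supplied at registration time; everything else is shared.
_SIMPLE_OPS = {}


class SimpleShardedTensor:
    """Toy stand-in for a sharded tensor: holds a list of local shards."""

    def __init__(self, shards):
        self.shards = list(shards)

    def __getattr__(self, name):
        # Any registered simple op becomes a method that validates its
        # arguments, then applies the op to each local shard independently.
        if name in _SIMPLE_OPS:
            validate, customized = _SIMPLE_OPS[name]

            def method(*args, **kwargs):
                validate(self, *args, **kwargs)
                return SimpleShardedTensor(
                    customized(shard, *args, **kwargs) for shard in self.shards
                )

            return method
        raise AttributeError(name)


def register_simple_op(name, customized_func, validate=None):
    """Register a shard-local op by supplying at most two small callbacks
    (the commit mentions at most three for the real implementation)."""
    _SIMPLE_OPS[name] = (validate or (lambda *a, **k: None), customized_func)


def _softmax_1d(shard):
    # Numerically stable softmax over one local shard (a list of floats).
    m = max(shard)
    exps = [math.exp(x - m) for x in shard]
    total = sum(exps)
    return [e / total for e in exps]


# Registering "softmax" needs only the per-shard function; the dispatch,
# validation hook, and result wrapping are reused from the helper.
register_simple_op("softmax", _softmax_1d)

st = SimpleShardedTensor([[1.0, 2.0], [3.0, 4.0]])
out = st.softmax()
```

This is what makes the approach cheap to extend: each new elementwise-style op contributes only its small per-shard kernel (and optionally a validator), while the registry handles dispatch uniformly.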