b9e3d722 - Use appropriate dtype for sharded linear implementation.

Use appropriate dtype for sharded linear implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79255

We use several collective operations in our sharded linear implementation, and for many of these collectives we do not set the `dtype` of the output tensor appropriately. As a result, using a dtype such as torch.float16 (rather than the default torch.float32) results in errors. This fixes the issue across the board and adds appropriate tests.

Differential Revision: [D37059752](https://our.internmc.facebook.com/intern/diff/D37059752/)

Approved by: https://github.com/fduwjj, https://github.com/wanchaol
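As a concrete illustration of the failure mode (a minimal sketch under stated assumptions, not the code touched by this commit; `gather_shards` is a hypothetical helper), the buggy pattern allocates collective output buffers with `torch.empty`'s default float32 dtype, while the fix derives the dtype from the input shard:

```python
# Minimal sketch of the dtype bug described above; gather_shards is a
# hypothetical helper, not part of the actual sharded linear code.
import torch
import torch.distributed as dist

def gather_shards(local_shard: torch.Tensor, world_size: int) -> torch.Tensor:
    # Buggy pattern: torch.empty defaults to torch.float32, so the
    # all_gather below errors when local_shard is torch.float16:
    #   out_shards = [torch.empty(local_shard.shape, device=local_shard.device)
    #                 for _ in range(world_size)]

    # Fixed pattern: derive the dtype (and device) from the input tensor.
    out_shards = [
        torch.empty(local_shard.shape,
                    dtype=local_shard.dtype,
                    device=local_shard.device)
        for _ in range(world_size)
    ]
    dist.all_gather(out_shards, local_shard.contiguous())
    return torch.cat(out_shards, dim=0)
```

`torch.empty_like(local_shard)` achieves the same effect more concisely, since the `*_like` factory functions preserve both dtype and device by default.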