pytorch
b56ba296 - Support multiple input dims for sharded linear. (#70266)

Commit
2 years ago
Support multiple input dims for sharded linear. (#70266) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70266 Addresses some of the issues mentioned in https://github.com/pytorch/pytorch/issues/65638. ShardedLinear implementation only support 2D inputs. On the other hand `nn.Linear` supports arbitrary dimensions for inputs and outputs. As a result, in this PR I've added support to ensure that ShardedLinear supports arbitrary input dims as well. ghstack-source-id: 147206607 Test Plan: waitforbuildbot Reviewed By: wanchaol Differential Revision: D33267630 fbshipit-source-id: 0460994c3aa33348b80547d9274206ef90cb29b6 (cherry picked from commit 7c289e1dbf491008e091ed0a49f98f2ebcfb4175)
Author
Committer
Parents
Loading