Support multiple input dims for sharded linear. (#70266)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70266
Addresses some of the issues mentioned in
https://github.com/pytorch/pytorch/issues/65638. ShardedLinear implementation
only support 2D inputs.
On the other hand `nn.Linear` supports arbitrary dimensions for inputs and
outputs. As a result, in this PR I've added support to ensure that
ShardedLinear supports arbitrary input dims as well.
ghstack-source-id: 147206607
Test Plan: waitforbuildbot
Reviewed By: wanchaol
Differential Revision: D33267630
fbshipit-source-id: 0460994c3aa33348b80547d9274206ef90cb29b6
(cherry picked from commit 7c289e1dbf491008e091ed0a49f98f2ebcfb4175)