[SPMD] Patch nn.Linear (#5491)
Summary:
This pull request introduces a patched version of torch.nn.functional.linear that uses einsum instead of torch.matmul. torch.matmul flattens the input tensor to 2D, collapsing the sharded dimensions together, and this default behavior makes it very hard for the XLA compiler to propagate the sharding annotations.
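For illustration, here is a minimal sketch of the einsum-based approach (the function name `einsum_linear` is hypothetical and not necessarily the code added by this PR): the einsum equation contracts only the feature dimension, so the leading (possibly sharded) dimensions of the input are never merged.

```python
import torch


def einsum_linear(input, weight, bias=None):
    # weight has shape (out_features, in_features); input has shape
    # (..., in_features). '...n,mn->...m' contracts only the last
    # dimension, keeping all leading (possibly sharded) dimensions intact,
    # unlike torch.matmul's flatten-to-2D behavior.
    output = torch.einsum('...n,mn->...m', input, weight)
    if bias is not None:
        output = output + bias
    return output
```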
Test Plan:
PJRT_DEVICE=CPU python test/test_operations.py -v -k test_patched_linear