xla commit 4f378e69: [SPMD] Patch nn.Linear (#5491)

Summary: This pull request introduces a patched version of torch.nn.functional.linear that uses einsum instead of torch.matmul. torch.matmul flattens the input tensors to 2D, collapsing the sharded dimensions, and this default behavior makes it very hard for the XLA compiler to propagate the sharding annotations.

Test Plan: PJRT_DEVICE=CPU python test/test_operations.py -v -k test_patched_linear
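For context, a minimal sketch of the approach described above; the function name patched_linear is hypothetical, and the actual patch lives in torch_xla. The key idea is that einsum with an ellipsis preserves the leading batch dimensions of the input, along with any sharding annotations attached to them, instead of flattening everything to 2D as torch.matmul's linear path does.

```python
import torch

def patched_linear(input, weight, bias=None):
    # Hypothetical sketch of the einsum-based linear described in the
    # commit message; not the actual torch_xla implementation.
    # weight has shape (out_features, in_features). The "..." keeps all
    # leading batch dims of `input` intact, so SPMD sharding annotations
    # on those dims can be propagated by the XLA compiler.
    output = torch.einsum("...n,mn->...m", input, weight)
    if bias is not None:
        output = output + bias
    return output
```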