5c4ce2a9 - [SPMD] Propagate replicated output (#5508)

Summary:
During the LLaMA2 experiments, I discovered that manually marking 1D tensors as replicated can save a lot of memory. Then I discovered that an explicitly replicated sharding spec gets dropped after mark_step. That is caused by PrepareOutputShardingPropagation, which explicitly clears the sharding spec for replicated outputs. So I went ahead and fixed that.

Further, I ran some experiments with propagating replicated output, and that removes the need to manually replicate 1D tensors. Hence, I made this change. I'm still not quite sure why; will follow up later.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py
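For context, a minimal sketch of the pattern the summary describes — explicitly marking a 1D tensor as replicated and then tracing through mark_step — assuming the torch_xla SPMD API around the time of this commit (Mesh and mark_sharding in torch_xla.experimental.xla_sharding). The mesh setup and tensor are illustrative assumptions, not part of the commit:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

# Enable SPMD mode (on older releases this was the XLA_USE_SPMD env var).
xr.use_spmd()

# Build a 1D device mesh over all available devices.
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.array(range(num_devices)), (num_devices,), ('x',))

# A 1D tensor (e.g. a bias); a partition spec of (None,) marks its
# single dimension as replicated across the mesh.
t = torch.zeros(128, device=xm.xla_device())
xs.mark_sharding(t, mesh, (None,))

# Before this fix, output sharding propagation cleared the replicated
# spec at mark_step; after the fix it is preserved on the output.
xm.mark_step()
```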