1845cb77 - [WIP][SPMD] Propagate replicated output

[WIP][SPMD] Propagate replicated output

Summary: During the LLaMA2 experiments, I discovered that manually marking 1D tensors as replicated saves a significant amount of memory. I then discovered that an explicitly replicated sharding spec gets dropped after mark_step. This is caused by PrepareOutputShardingPropagation, which explicitly clears the sharding spec for replicated outputs, so I went ahead and fixed that. Further, I ran some experiments with propagating replicated output, and that removes the need to manually replicate 1D tensors, hence this change. I'm still not quite sure why, so this is WIP.

Test Plan: PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py
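For context, a minimal sketch of what "manually marking a 1D tensor to be replicated" looks like under SPMD, assuming the torch_xla experimental sharding API of this era (torch_xla.experimental.xla_sharding); the mesh shape and tensor size here are illustrative, not taken from this commit:

    # Sketch: explicitly replicating a 1D tensor across an SPMD mesh.
    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.experimental.xla_sharding as xs

    # Build a 1D mesh over all local devices (e.g. 8 TPU cores).
    num_devices = xm.xrt_world_size()
    mesh = xs.Mesh(np.arange(num_devices), (num_devices,))

    t = torch.zeros(128, device=xm.xla_device())

    # A partition spec of (None,) replicates the tensor across the mesh
    # rather than sharding its single dimension.
    xs.mark_sharding(t, mesh, (None,))

    # Before this fix, PrepareOutputShardingPropagation dropped such an
    # explicitly replicated spec after the step boundary.
    xm.mark_step()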