[WIP][SPMD] Propagate replicated output
Summary:
During the LLaMA2 experiments, I discovered that manually marking 1D tensors
as replicated can save a significant amount of memory (sketched below). I then
discovered that an explicitly replicated sharding spec gets dropped after
mark_step. This is caused by PrepareOutputShardingPropagation, which explicitly
clears the sharding spec for replicated outputs. So, I went ahead and fixed that.
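For context, here is a minimal sketch of the manual workaround described above, assuming the torch_xla SPMD API (Mesh / mark_sharding); the tensor name, shape, and device count are illustrative and not taken from the LLaMA2 setup:

```python
# Sketch only: assumes SPMD mode is enabled (e.g. XLA_USE_SPMD=1) and a
# TPU/PjRt runtime; names, shapes, and the device count are illustrative.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs  # module path may differ by version

num_devices = 4  # illustrative; query the runtime for the real device count
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("data",))

bias = torch.zeros(4096, device=xm.xla_device())
# Explicitly replicate the 1D tensor: a partition spec of (None,) means its
# single dimension is not sharded along any mesh axis.
xs.mark_sharding(bias, mesh, (None,))

# Before this fix, the replicated sharding spec attached above was cleared
# by PrepareOutputShardingPropagation on the next mark_step().
xm.mark_step()
```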
Further, I experimented with propagating the replicated output, and that
removes the need to manually replicate 1D tensors. Hence, I made this change.
I'm still not quite sure why this works, so this is WIP.
Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py