[SPMD] Support mark_sharding on IRs (#5301)
Summary:
This pull requests fixes the recompilation issue in xs.mark_sharding().
xtensor->GetXlaData() will compile the program if xtensor is an IR in order
to get the BackendData. I believe this is not intended given the error message
below suggests only data type xtensors are supported.
Test Plan:
PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_xla_sharding.py