[SPMD] Add apply_backward_optimization_barrier
Summary:
This pull request adds a new API to xla_sharding.py, apply_backward_optimization_barrier,
which registers a full backward hook that applies an optimization barrier to the given module.
The barrier prevents the XLA compiler from fusing the module's backward pass with those of other modules.
This is useful for preventing gigantic buffers from being allocated to synchronize the gradients.
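
Illustration:
The hook mechanism described above can be sketched in plain PyTorch. This is a minimal sketch and not the code in this pull request: apply_backward_barrier_sketch and fake_optimization_barrier_ are hypothetical names, and the stand-in barrier only records that it was called, where the real API would presumably invoke an XLA optimization barrier on the same tensors.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for an XLA optimization barrier: it only records
# how many tensors it was handed, so the sketch runs without an XLA device.
barrier_calls = []


def fake_optimization_barrier_(tensors):
    barrier_calls.append(len(tensors))


def apply_backward_barrier_sketch(module, barrier=fake_optimization_barrier_):
    """Register a full backward hook that passes gradients to `barrier`."""

    def hook(mod, grad_input, grad_output):
        # Collect parameter gradients that already exist, plus the
        # gradients flowing to the module's inputs.
        grads = [p.grad for p in mod.parameters() if p.grad is not None]
        grads += [g for g in grad_input if g is not None]
        barrier(grads)
        # Returning None leaves grad_input unchanged.

    return module.register_full_backward_hook(hook)


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
# One barrier per layer, so each layer's backward pass is fenced separately.
for layer in model:
    apply_backward_barrier_sketch(layer)

x = torch.randn(3, 4, requires_grad=True)
model(x).sum().backward()
```

Applying the hook per layer rather than once on the whole model is an assumption about typical usage; the granularity depends on which backward passes should be kept from fusing.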
Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier