[SPMD] Add apply_backward_optimization_barrier (#6097)
Summary:
This pull request adds a new API, apply_backward_optimization_barrier, to xla_sharding.py. It registers a full backward hook that applies an optimization barrier to the given module, which prevents the XLA compiler from fusing the module's backward pass with others. This is useful for avoiding the allocation of gigantic buffers when synchronizing the gradients.
It's also used in pytorch-tpu/transformers#50.
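To illustrate the mechanism, here is a minimal sketch of the full-backward-hook pattern the API is built on. It runs on CPU with plain PyTorch: the `barrier` helper is a hypothetical no-op stand-in for the XLA optimization barrier (the real API applies torch_xla's barrier to the gradients instead), and `apply_backward_barrier_sketch` is an illustrative name, not the actual function.

```python
import torch
import torch.nn as nn

def apply_backward_barrier_sketch(module: nn.Module) -> nn.Module:
    # Hypothetical stand-in for the XLA optimization barrier so the
    # example runs on CPU; the real API routes gradients through an
    # XLA barrier op to block cross-module fusion in the backward pass.
    def barrier(tensors):
        return [t.clone() if t is not None else None for t in tensors]

    def hook(mod, grad_input, grad_output):
        # Returning a new tuple replaces grad_input; this is the point
        # where the real barrier cuts the backward graph.
        return tuple(barrier(list(grad_input)))

    module.register_full_backward_hook(hook)
    return module

# Usage: wrap a module, then run forward/backward as usual.
linear = apply_backward_barrier_sketch(nn.Linear(4, 2))
out = linear(torch.randn(3, 4, requires_grad=True))
out.sum().backward()
```

The hook fires once per backward pass of the wrapped module, so every gradient flowing into it passes through the barrier before reaching upstream modules.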
Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier