xla
07540f21 - [SPMD] Add apply_backward_optimization_barrier (#6097)

Commit
2 years ago
[SPMD] Add apply_backward_optimization_barrier (#6097) Summary: This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier where registers a full backward hook that apply an optimization barrier to the given module. This API will prevent the XLA compiler from fusing the module's backward pass with others. And It's useful to prevent gigantic buffers being allocated to synchronize the gradients. It's also used in pytorch-tpu/transformers#50. Test Plan: python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier
Author
Parents
Loading