jax
bd8765d3 - Add collective_axes to run_scoped

Commit
293 days ago
Add collective_axes to run_scoped Our current allocation scheme on GPU is unsafe in presence of multiple threads that might take diverging control paths. We work around this problem using our favorite trick and simply forbid this! With this change, `run_scoped(..., collective_axes="wg")` means that the same allocation will be returned in all programs that only differ in the `wg` axis. What's more, this call is a user promise that the allocation is a collective that will be executed by all threads along that axis. Only executing it on a subset is undefined behavior and in our current Mosaic GPU implementation might lead to deadlocks due to barriers. Note that nothing changes for single-threaded kernels, where run_scoped is always allowed. PiperOrigin-RevId: 757734362
Author
Parents
Loading