jax
a25a24df - [Mosaic:GPU] Add test-case for several mosaic ops.

Commit
2 days ago
[Mosaic:GPU] Add test-case for several mosaic ops.

This test covers the case where an HLO on the XLA side contains several Mosaic custom calls.

The problem is that the buffer assigner can reuse the physical buffer assigned to the additional scratch output buffer (used to copy collective metadata) across several operations in the HLO. This can lead to the following situation:

  Thunk 1 initialize: stores metadata 1
  Thunk 2 initialize: stores metadata 2
  Thunk 1 execute: uses metadata from thunk 2

To prevent this, we currently copy the collective metadata to the scratch buffer at every execution, and this test case should cover that logic.

We also disable CUDA graphs for this test, since they are currently not compatible with collective kernels (neither with NVSHMEM nor with collective metadata).

PiperOrigin-RevId: 876265600
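The thunk interleaving described in the commit message can be illustrated with a small toy model. This is only a sketch in plain Python: `ToyThunk`, `run`, and the one-element `scratch` list are hypothetical names invented for illustration and do not mirror XLA's actual thunk or buffer-assignment APIs. It assumes both thunks were handed the same physical scratch buffer and that all `initialize` calls run before any `execute` call.

```python
# Toy model only: hypothetical classes, not XLA's real thunk API.

class ToyThunk:
    def __init__(self, metadata, scratch, copy_on_execute):
        self.metadata = metadata
        # `scratch` stands in for the physical buffer that the buffer
        # assigner may hand to more than one operation.
        self.scratch = scratch
        self.copy_on_execute = copy_on_execute

    def initialize(self):
        # Store this thunk's collective metadata in the scratch buffer.
        self.scratch[0] = self.metadata

    def execute(self):
        # The mitigation: copy the metadata again right before use.
        if self.copy_on_execute:
            self.scratch[0] = self.metadata
        # Return whatever metadata the kernel would actually observe.
        return self.scratch[0]


def run(copy_on_execute):
    scratch = [None]  # one shared buffer, reused by both thunks
    thunks = [
        ToyThunk("metadata 1", scratch, copy_on_execute),
        ToyThunk("metadata 2", scratch, copy_on_execute),
    ]
    for thunk in thunks:
        thunk.initialize()
    return [thunk.execute() for thunk in thunks]


print(run(copy_on_execute=False))  # thunk 1 sees thunk 2's metadata
print(run(copy_on_execute=True))   # each thunk sees its own metadata
```

Without the copy at execution time, the first thunk in this model reads the metadata written by the second thunk's `initialize`. With the copy, each thunk observes its own metadata, at the cost of one extra copy per execution.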