jax
0a66e2d0 - [Pallas/MGPU] Fix a race in the pipelining code

Commit
1 year ago
[Pallas/MGPU] Fix a race in the pipelining code

We never checked whether the output windows had finished writing before we reused them.

Also, rename num_stages to max_concurrent_steps, since we always have only 2 stages but might be running multiple iterations at a time.

Also fix the test for this, which had been passing for reasons that I don't understand (it didn't even write to all entries in the output??).

PiperOrigin-RevId: 679148961
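To illustrate the kind of race described above, here is a minimal pure-Python sketch. It is not the Pallas/Mosaic GPU code. The function `run_pipeline`, the `wait_before_reuse` flag, and the "lazy copy" model of an in-flight write are all hypothetical and exist only for this illustration. Only the name `max_concurrent_steps` comes from the commit message. The sketch models each output window as a staging slot whose copy to the final output completes as late as possible, so reusing a slot without waiting lets the pending copy observe the newer data.

```python
def run_pipeline(num_steps, max_concurrent_steps, wait_before_reuse):
    """Toy model of a pipelined loop with a ring of output windows.

    Hypothetical illustration only; not the Pallas/MGPU implementation.
    """
    slots = [None] * max_concurrent_steps  # output windows (staging buffers)
    pending = {}                           # slot -> output index of in-flight copy
    out = [None] * num_steps               # final output

    def finish_copy(slot):
        # The copy reads the slot when it completes, not when it was issued.
        idx = pending.pop(slot)
        out[idx] = slots[slot]

    for step in range(num_steps):
        slot = step % max_concurrent_steps
        if wait_before_reuse and slot in pending:
            # The fix: wait for the previous write-out before reusing the window.
            finish_copy(slot)
        slots[slot] = step * 10            # "compute" into the window
        if slot in pending:
            # Racy path: the old copy is still in flight and now sees new data.
            finish_copy(slot)
        pending[slot] = step               # issue the asynchronous write-out

    for slot in list(pending):             # drain the remaining copies
        finish_copy(slot)
    return out


print(run_pipeline(4, 2, wait_before_reuse=True))
print(run_pipeline(4, 2, wait_before_reuse=False))
```

With the wait in place, every output entry receives the value computed for its own step. Without it, earlier entries are overwritten by values from later steps that reused the same window.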