jax
75b11ec2 - 7x device mesh fall back to unoptimized when needed

Commit
207 days ago
7x device mesh fall back to unoptimized when needed Without this change creating a mesh of 32 devices out of 64 devices fails on v7x-64. The following error will be seen: ``` JAX device count: 64 Using 32 devices for the mesh. Attempting jax.make_mesh((1, 32), ('fsdp', 'tp')) AssertionError: Unexpected physical mesh shape: (2, 4, 2, 2). ``` PiperOrigin-RevId: 828008412
Author
Parents
Loading