jax
33311b50 - [Pallas] Support the new TPU interpret mode with pl.core_map.

Commit
1 year ago
[Pallas] Support the new TPU interpret mode with pl.core_map. NOTE: The new TPU interpret mode does not yet support Megacore, so this only enables pl.core_map over a TensorCoreMesh with shape (axis_name, 1). Also adds a num_cores argument to pltpu.create_tensorcore_mesh.
Author
Committer
Parents
Loading