jax
dd958adc - Add `mesh_shape` to the lowering context. This is to allow custom partitioning to not depend on the mesh context manager to return NamedShardings even if the arguments have NamedShardings on them.

Commit
1 year ago
Add `mesh_shape` to the lowering context. This is to allow custom partitioning to not depend on the mesh context manager to return NamedShardings even if the arguments have NamedShardings on them. Since `shardy`, sharding in types work, world 2 dagger is going in a direction of making Mesh and PartitionSpec a first class sharding type, let's pull the trigger right now to start fixing these bad user interactions. Some things that will break due to this change: Before passing NamedSharding and an equivalent PositionalSharding to the same jitted function one after another would lead to a lowering cache hit. But now we will cache miss. In other words: `f(ns); f(ps) # cache hit before` In followup CLs, we will make the tracing cache aware of the mesh shape too to fix some other issues related to tracing and lowering cache misses PiperOrigin-RevId: 660177423
Author
Committer
Parents
Loading