jax
c8b65463 - Add `is_ref` to `ShapeDtypeStruct` to allow doing AOT with duck types.

Commit
242 days ago
Add `is_ref` to `ShapeDtypeStruct` to allow doing AOT with duck types. Users create `ShapeDtypeStruct` with concrete mesh and pspec and pass those to `jax.jit` at the top level. So the reason for adding this to `ShapeDtypeStruct` and not creating a `NewStruct = tuple[ShapedArray | AbstractRef, ExtraMetadata]` is because: `ShapedArray`'s `.sharding` attribute only mentions the abstract mesh and explicit pspec. So if the user had `NamedSharding(Mesh((x, 2, Auto), (y, 2, Auto)), P('x', 'y')` on the input `NewStruct`, we would lose the device information and the pspec information since the pspec mentions only explicit axes. So the user will have to pass that information again (redundantly) in the `ExtraMetadata` section of `NewStruct`. This doesn't seem ideal :( Adding a `is_ref` field to `ShapeDtypeStruct` sidesteps these concerns while still keeping the door open for a better system to create duck type inputs when there are multiple user types (jax.Array, ArrayRef, Tuple, etc). PiperOrigin-RevId: 781697713
Author
Parents
Loading