Add `is_ref` to `ShapeDtypeStruct` to allow doing AOT with duck types.
Users create `ShapeDtypeStruct` with concrete mesh and pspec and pass those to `jax.jit` at the top level. So the reason for adding this to `ShapeDtypeStruct` and not creating a `NewStruct = tuple[ShapedArray | AbstractRef, ExtraMetadata]` is because:
`ShapedArray`'s `.sharding` attribute only mentions the abstract mesh and explicit pspec. So if the user had `NamedSharding(Mesh((x, 2, Auto), (y, 2, Auto)), P('x', 'y')` on the input `NewStruct`, we would lose the device information and the pspec information since the pspec mentions only explicit axes.
So the user will have to pass that information again (redundantly) in the `ExtraMetadata` section of `NewStruct`. This doesn't seem ideal :(
Adding a `is_ref` field to `ShapeDtypeStruct` sidesteps these concerns while still keeping the door open for a better system to create duck type inputs when there are multiple user types (jax.Array, ArrayRef, Tuple, etc).
PiperOrigin-RevId: 781697713