Raise an error on non-hashable static arguments for jax.jit and xla_computation.
Up to now, Jax was silently wrapping the object to ensure objects which are not hashable will be hashed using `id` and compared using `is`:
```
class WrapHashably(object):
__slots__ = ["val"]
def __init__(self, val):
self.val = val
def __hash__(self):
return id(self.val)
def __eq__(self, other):
return self.val is other.val
```
This means that when providing different instances of objects that are non hashable, a recompilation was always occurring. This can be non-intuitive, for example with:
@partial(jax.jit, static_argnums=(1,))
def sum(a, b):
return a+ b
sum(np.asarray([1,2,3]), np.asarray([4,5,6])
# The next line will recompile, because the 1-indexed argument is non
# hashable and thus compared by identity with different instances
sum(np.asarray([1,2,3]), np.asarray([4,5,6])
or more simply
np.pad(a, [2, 3], 'constant', constant_values=(4, 6))
^^^^^^
non-hashable static argument.
The same problems can occur with any non-hashable types such as lists, dicts, etc. Even JAX itself was having some issues with this (which shows the behaviour was non-trivial to reason about).
If this commit breaks you, you usually have one of the following options:
- If specifying numpy array or jnp arrays arguments as static, you probably simply need to make them non static.
- When using non-hashable values, such as list, dicts or sets, you can simply use non-mutable versions, with tuples, frozendict, and frozenset.
- You can also change the way the function is defined, to capture these non-hashable arguments by closure, returning the jitted function.
PiperOrigin-RevId: 339351798