Add `AbstractDevice` to `AbstractMesh` to capture specific device properties; `device_kind` and `num_tpu_cores`.
Main implication is that this puts `AbstractDevice` into tracing cache key so tracing the same function on CPU and TPU one after another would lead to a tracing cache miss. We think this is rare enough and it's fine to get a cache miss to solve the pallas issue. Note that this only happens if the user set the mesh context via `jax.set_mesh(mesh)`.
Pallas already specializes on `device_kind` and `num_tpu_cores` during tracing but in a weird way (via `get_default_device()`). This change just makes it more official and exposes a standard way of branching on such info during trace time which is also cache sensitive.
PiperOrigin-RevId: 796542802