02c19e96 - Make `jax.grad` and `compute_on` work correctly. If the forward pass has annotation to execute on CPU, then it's backward pass also executes on CPU.
Make `jax.grad` and `compute_on` work correctly. If the forward pass has annotation to execute on CPU, then it's backward pass also executes on CPU.
PiperOrigin-RevId: 634917402