jax
02c19e96 - Make `jax.grad` and `compute_on` work correctly. If the forward pass has annotation to execute on CPU, then it's backward pass also executes on CPU.

Commit
2 years ago
Make `jax.grad` and `compute_on` work correctly. If the forward pass has annotation to execute on CPU, then it's backward pass also executes on CPU. PiperOrigin-RevId: 634917402
Author
Committer
Parents
Loading