jax
e86ec28c - Recast int/bool tangents to float0 in custom_jvp/vjps (also in the

Commit
5 years ago
Recast int/bool tangents to float0 in custom_jvp/vjps (also in the initial_style path). PiperOrigin-RevId: 336082045
References
Author
Committer
Parents
Loading