flax
Stop passing reduce_axes to jax.grad, jax.vjp, and jax.value_and_grad.
#4617
Merged

Stop passing reduce_axes to jax.grad, jax.vjp, and jax.value_and_grad. #4617

copybara-service merged 1 commit into main from test_735883853
copybara-service
copybara-service copybara-service force pushed from cc8bb66b to 7b132af4 1 year ago
copybara-service copybara-service force pushed from 7b132af4 to f7cfcfbd 364 days ago
copybara-service copybara-service force pushed from f7cfcfbd to 44ab6027 363 days ago
copybara-service copybara-service force pushed from 44ab6027 to df0e7156 363 days ago
copybara-service copybara-service force pushed from df0e7156 to 8cd8994c 362 days ago
copybara-service copybara-service force pushed from 8cd8994c to 3cf75976 362 days ago
copybara-service copybara-service force pushed from 3cf75976 to 6b676013 362 days ago
jburnim Stop passing reduce_axes to jax.grad, jax.vjp, and jax.value_and_grad.
f18baf07
copybara-service copybara-service force pushed from 6b676013 to f18baf07 362 days ago
copybara-service copybara-service merged f18baf07 into main 362 days ago
copybara-service copybara-service deleted the test_735883853 branch 362 days ago

Login to write a write a comment.

Login via GitHub

Reviewers
No reviews
Assignees
No one assigned
Labels
Milestone