flax
Improve partial_eval_by_shape compatibility with https://github.com/google/jax/pull/3370 and custom derivative rules.
#405
Merged

Loading