jax
7f05b74b - Fix wrong results in multidimensional pad.

Commit
336 days ago
Fix wrong results in multidimensional pad. When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order. We lacked test coverage for this case because constant values ((0, 2),) and (0, 2) were handled by different code paths. Fixes https://github.com/jax-ml/jax/issues/26888
Author
Committer
Parents
Loading