Fix wrong results in multidimensional pad.
When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order.
We lacked test coverage for this case because constant values ((0, 2),) and (0, 2) were handled by different code paths.
Fixes https://github.com/jax-ml/jax/issues/26888