jax
6ce4314f - lax: ensure padtype_to_pads returns Python ints

Commit
41 days ago
lax: ensure padtype_to_pads returns Python ints Fix a bug where padtype_to_pads could return NumPy scalar values in padding tuples for certain shapes and stride configurations, leading to verbose NumPy scalars in jaxprs. Add a regression test to ensure padding values are always Python ints.
Author
Committer
Parents
Loading