[JAX] Use JAX test utils to assert array equality/closeness
This change updates some JAX API tests to use JAX's own test utils for
comparing array contents instead of using numpy operations. This gives a
consistent testing pattern, and a stronger check around array dtypes.
PiperOrigin-RevId: 783110873