revise axes_scan to flatten argument pytrees only once
A user has a custom pytree node with the unusual behavior that it introduces
new arrays when flattening. That is, it's as if we had:
```python
# a custom object with two leaf arrays
custom_tree_object = SomeObject(jax_arrray1, jax_array2)
# convert leaves to ShapedArrays
custom_tree_object2 = jax.tree.map(core.typeof, custom_tree_object)
# flatten, should only see ShapedArrays, right?
leaves, treedef = jax.tree.flatten(custom_tree_object2)
print(leaves)
# [ShapedArray(...), ShapedArray(...), np.array(...)]
```
This change makes the `flax.nn.scan` function robust to such behavior. Without it, we were passing non-AbstractValues into JAX where JAX required AbstractValues.
I don't think we want to support this in general, but this fix seemed like the most
expedient way to roll fowrard https://github.com/jax-ml/jax/pull/29273
PiperOrigin-RevId: 768175118