Flax is changing the `RNNCellBase` API:
- when calling the constructor of a class, it is now required to pass in a `features` argument
- when calling the `initialize_carry` method, instead of passing in the `batch_dims` and `size`, you only have to pass in an `input_shape`
More details about the changes and how to upgrade to the new API can be found [here](https://flax--3053.org.readthedocs.build/en/3053/guides/rnncell_upgrade_guide.html).
PiperOrigin-RevId: 544461085