[sharding_in_types] Add scan support to sharding_in_types. There are a couple of changes here
* Set abstract_mesh context manager during pjit_p.bind at the top level too since scan builds jaxpr during it's lowering in `_scan_impl` (do the same for AOT path)
* Set the abstract mesh only once if it's not set. Don't override an already set context. This means that only top level jit sets the context manager.
* Add dynamic_slice and dynamic_update_slice sharding rules since scan calls into them.
* scan only allows `xs` where the 0th dim is full replicated i.e. None.
PiperOrigin-RevId: 699014167