jax
355589f3 - [sharding_in_types] Add scan support to sharding_in_types. There are a couple of changes here

Commit
1 year ago
[sharding_in_types] Add scan support to sharding_in_types. There are a couple of changes here * Set abstract_mesh context manager during pjit_p.bind at the top level too since scan builds jaxpr during it's lowering in `_scan_impl` (do the same for AOT path) * Set the abstract mesh only once if it's not set. Don't override an already set context. This means that only top level jit sets the context manager. * Add dynamic_slice and dynamic_update_slice sharding rules since scan calls into them. * scan only allows `xs` where the 0th dim is full replicated i.e. None. PiperOrigin-RevId: 699014167
Author
Parents
Loading