jax
acb56a29 - [export] Fix calling under pmap of exported computation with polymorphic shapes

Commit
1 year ago
[export] Fix calling under pmap of exported computation with polymorphic shapes If we call a computation with shape polymorphism under pmap we must refine the shapes before we compile. We follow the same pattern for `UnloadedPmapExecutable` as for `UnloadedMeshExecutable`: we store the `shape_poly_state` from the `LoweringResult` into the `compile_args` and we call `refine_polymorphic_shapes`. Without this fix we may end up trying to compile HLO with dynamic shapes.
Author
Committer
Parents
Loading