jax
602f212d - [pallas] Make the LoweringDynamicShapeEnv use local mappings

Commit
21 days ago
[pallas] Make the LoweringDynamicShapeEnv use local mappings This fixes Pallas behavior in presence of shape polymorphism. Previously, the LoweringDynamicShapeEnv was using global mappings dim_expr_to_placeholder and placeholder_to_dim_expr. This results in errors where we carry-over dim_expr from one lowering to the next. Concretely, this resulted in failures of the form ``` Encountered dimension variable 'b' that is not appearing in the shapes of the function arguments. ``` because the symbolic variable 'b' is carried over from a previous test with Pallas and shape polymorphism.
Author
Committer
Parents
Loading