jax
9dcb93f0 - #pallas Avoid scaling the pallas_call CostEstimate by the batching dimension size if it is a dynamic shape.

Commit
204 days ago
#pallas Avoid scaling the pallas_call CostEstimate by the batching dimension size if it is a dynamic shape. When using dynamic shapes (exporting module) we can't accurately scale the cost estimates. In this case just strip them from the pallas call. PiperOrigin-RevId: 827901821
Parents
Loading