jax
9def0f1c - [pallas] Add limited support for shape polymorphism for TPU

Commit
1 year ago
[pallas] Add limited support for shape polymorphism for TPU The main change is to pass the `result_shapes` to the hlo.CustomCallOp when the output shapes contain dimension variables. Everything else is already handled by the support for dynamic bounds sizes for TPU. Note that this CL only adds limited support for shape polymorphism: only on TPU, and only when the block sizes are static. PiperOrigin-RevId: 648409699
Author
Committer
Parents
Loading