jax
50e2a612 - make a "hello world" of pallas_call + pinned buffers

Commit
80 days ago
make a "hello world" of pallas_call + pinned buffers The purpose of this change is to adapt `pallas_call` so as to accept the type produced by `pin` (currently called `AbstractLinVal`), and to produce a value of the same type at the corresponding aliased output (as indicated by `input_output_aliases`, which is now required to be populated for all such arguments). While the `pallas_call` can be applied to values of type `AbstractLinVal`, and it can produce values of the same type, the corresponding binders on the kernel body function remain of type `AbstractRef(ShapedArray)`. So this change introduces an intentional divergence between the avals consumed and produced by the `pallas_call` and those of the kernel body. The main changes are: 1. In `_pallas_call_abstract_eval`, we now produce output avals corresponding to entries in `input_output_aliases` by reading the input avals, thus matching any `AbstractLinVal`s passed in the input `avals`. We don't use the parameter `out_avals`, since those more closely reflect the kernel body's output types, based on how they are inferred given the user-provided `out_shape`. 2. In the `pallas_call` lowering rule, we set `kernel_out_avals` according to the view-from-the-caller `ctx.avals_out` so as to pass in the `AbstractLinVal`s and ultimately generate an MHLO lowering that produces the memref types we intend. PiperOrigin-RevId: 819479255
Author
Parents
Loading