jax
70f3b181 - Create `pl.select_ref` to dynamically choose from multiple refs for DMAs.

Commit
19 days ago
Create `pl.select_ref` to dynamically choose from multiple refs for DMAs. This supports nesting select and existing transforms like indexing. Example use: ``` x_ref = pl.select_ref(idx, x0_ref.reshape(...), x1_ref).at[<...>] pltpu.async_copy(x_ref, y_ref, sem).wait() ``` Also expanded `TransformedRef` to support multi-ref cases, while keep the single-ref structure unchanged. PiperOrigin-RevId: 910330932
Author
Parents
Loading