jax
1f0b5728 - Add a memory saving index rewrite step to vmap with ragged inputs over pallas_call.

Commit
1 year ago
Add a memory saving index rewrite step to vmap with ragged inputs over pallas_call. The approach here is to add a new notion to jax, for ragged_prop. Ragged prop is useful for computing the dynamism/raggedness of an output, given a set of inputs. In the limit, if we decide that this is a useful property to have in jax as a first class citizen, we could fold the raggedness into the type system. At the moment, however, it is just a small set of rules implemented per op. PiperOrigin-RevId: 685827096
Parents
Loading