jax
eab1dfcc - [Pallas] Generalize BlockSpec to support different indexing mode for each dim in the block shape

Commit
323 days ago
[Pallas] Generalize BlockSpec to support different indexing mode for each dim in the block shape Currently block_shape is tuple[int | None, …]. We propose generalizing block_shape to take in more types in the tuple to more generally support: * Squeeze dimension (currently None, could be pl.Squeezed()) * Unblocked: currently the entire index_map needs to be Unblocked or not. This will allow individual indices to be Blocked/Unblocked, e.g. pl.BlockSpec((pl.Unblocked(...), 512), …) * Ragged sizes: the index_map will return a pl.ds with a dynamic size (bounded by some something). For example: pl.BlockSpec((pl.DynamicSizedSlice(512), 1024), lambda i, j: (pl.ds(...), j). This will make BlockSpecs a lot more flexible and will enable things like doing arbitrary slicing in things like pipeline emitter. PiperOrigin-RevId: 748881960
Author
Parents
Loading