jax
d3c8de76 - Allow `prng.KeyTy` as a valid `RooflineShape.dtype`.

Commit
221 days ago
Allow `prng.KeyTy` as a valid `RooflineShape.dtype`. Prior to this change, primitives from `prng` (e.g. `random_wrap`) raised `TypeError`s because their inputs/outputs are of type `prng.KeyTy`. This change supports these types as valid and tests that primitives like `random_wrap` work. jaxpr created by this test: ``` { lambda ; a:u32[2]. let b:key<fry>[] = random_wrap[impl=fry] a c:key<fry>[2] = random_split[shape=(2,)] b d:u32[2,2] = random_unwrap c e:u32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1)] d f:u32[2] = squeeze[dimensions=(0,)] e in (f,) } ``` PiperOrigin-RevId: 776664179
Author
Parents
Loading