jax
678adf59 - [Mosaic GPU] Fix issue in `vector_dim` reduction with `vec_len > 2`.

Commit
1 day ago
[Mosaic GPU] Fix issue in `vector_dim` reduction with `vec_len > 2`. `_lift_fast_packed_instr`, which is used for `min` and `max` reductions for `f16` and `bf16` wraps its output inside a `vector` type, even for a single element. When reducing across a `vector_dim` with more than 2 elements, this yields a type mismatch on the third reduction loop step: ``` # Step 1 scalar_out_reg = None scalar = vector.extract ... : f16 -> scalar_out_reg = scalar : f16 # Step 2 scalar_out_reg : f16 scalar = vector.extract ... : f16 -> scalar_out_reg = self._lift_fast_packed_instr(...)(scalar, scalar_out_reg) : vector<1xf16> # Step 3 scalar_out_reg : vector<1xf16> scalar = vector.extract ... : f16 -> Type mismatch ``` PiperOrigin-RevId: 855689624
Author
Parents
Loading