jax
e4fa0259 - [Pallas:MGPU] Lower `lax.sign` consistently for LANE and WG semantics.

Commit
137 days ago
[Pallas:MGPU] Lower `lax.sign` consistently for LANE and WG semantics. We lower `lax.sign` as follow: * For floats: `sign(x) = select(x != 0, copysign(1.0, x), 0.0)` * For unsigned integers: `sign(x) = (x != 0)` * For signed integers: `sign(x) = (x > 0) - (x < 0)` PiperOrigin-RevId: 855219472
Author
Parents
Loading