jax
7448556f - Allow reduced on the fwd pass to jnp.sin even though unreduced is not allowed as an input to sin since sin is a non-linear op.

Commit
58 days ago
Allow reduced on the fwd pass to jnp.sin even though unreduced is not allowed as an input to sin since sin is a non-linear op.

**Then why allow reduced to a non-linear op?**

When unreduced is introduced on the backward pass (since it's the cotangent type of reduced), there will be no non-linearity present!

Let's take the sin example. JAX will do the `cos` on the primal side and the tangent computation will only contain `mul(g, cos_residual)` and `mul` is a bilinear op which can ingest unreduced on one of the operands.

**Those were a lot of words, let's look at a concrete example and see what the forward and backward pass will look like**

Consider this example:

```
@jax.jit
def f(x: f32[4, 2]{R: x}, y: f32[2, 8@x]):
  x_: f32[4, 2]{R: x} = jnp.sin(x)
  z: f32[4, 8@x] = x_ @ y
  return z.sum()
```

The backward pass would look like this if you do: `jit(grad(f))`

```
def f_bwd(res, dz):
  (cos_x: f32[4, 2]{R: x},) = res
  dx_: f32[4, 2]{U: x} = dz: f32[4, 8@x] @ y.T: f32[8@x, 2]
  dx: f32[4, 2]{U: x} = mul(dx_, cos_x)
  return dx
```

As you can see `mul` gets 2 inputs `dx_` and `cos_x` (which is the residual). Since mul is bilinear, if one of the inputs is unreduced along a mesh axis, then the other input **has to be reduced** along the same mesh axis.

There's another thing to note: After grad, the type of `dx` will be `f32[4, 2]{U: x}` which makes sense because the type of `x` was `f32[4, 2]{R: x}`!

PiperOrigin-RevId: 829096082
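The bilinearity argument above can be checked numerically without any sharding API. The following is a minimal sketch, not the implementation in this commit: it assumes a mesh axis of size 2 and emulates it by splitting `y` and `dz` with `jnp.split`, so each list entry stands in for one device's shard. The names `num_shards`, `dx_partials`, `dx_reduce_last`, and `dx_reduce_first` are hypothetical and exist only for this illustration.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # Same computation as the example above, minus the sharding annotations.
    return (jnp.sin(x) @ y).sum()

# Small deterministic inputs with the shapes used in the example.
x = jnp.arange(8.0).reshape(4, 2) / 8.0
y = jnp.arange(16.0).reshape(2, 8) / 16.0

# Reference gradient from ordinary JAX autodiff.
dx_ref = jax.grad(f)(x, y)

# Assumption: emulate a mesh axis of size 2 by splitting along the 8-wide dim.
num_shards = 2
dz = jnp.ones((4, 8))                          # cotangent of z.sum()
y_shards = jnp.split(y, num_shards, axis=1)    # each has shape (2, 4)
dz_shards = jnp.split(dz, num_shards, axis=1)  # each has shape (4, 4)

# Each emulated device holds only a partial sum of dx_ (the "unreduced" value).
dx_partials = [dz_s @ y_s.T for dz_s, y_s in zip(dz_shards, y_shards)]

# The residual is identical on every emulated device.
cos_x = jnp.cos(x)

# Multiply per shard, then sum across shards (the reduction happens last).
dx_reduce_last = sum(p * cos_x for p in dx_partials)
# Sum across shards first, then multiply.
dx_reduce_first = sum(dx_partials) * cos_x

mul_commutes = bool(jnp.allclose(dx_reduce_last, dx_reduce_first, atol=1e-5))
matches_grad = bool(jnp.allclose(dx_reduce_last, dx_ref, atol=1e-5))

# A non-linear op applied to the partial sums does not commute with the sum.
sin_commutes = bool(
    jnp.allclose(sum(jnp.sin(p) for p in dx_partials),
                 jnp.sin(sum(dx_partials)), atol=1e-5)
)

print(mul_commutes, matches_grad, sin_commutes)
```

Because `mul` distributes over the pending sum, the reduction can be deferred past it, whereas applying `sin` to each partial sum gives a different result from applying it to the reduced value.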