Allow reduced on the fwd pass to jnp.sin even though unreduced is not allowed as an input to sin since sin is a non-linear op.
**Then why allow reduced to a non-linear op?**
When unreduced is introduced on the backward pass (since it's the cotangent type of reduced), there will be no non-linearity present!
Let's take the sin example. JAX will do the `cos` on the primal side and the tangent computation will only contain `mul(g, cos_residual)` and `mul` is a bilinear op which can ingest unreduced on one of the operands.
**Those were a lot of words, let's look at a concrete example and see what the forward and backward pass will look like**
Consider this example:
```
@jax.jit
def f(x: f32[4, 2]{R: x}, y: f32[2, 8@x]):
x_: f32[4, 2]{R: x} = jnp.sin(x)
z: f32[4, 8@x] = x_ @ y
return z.sum()
```
The backward pass would look like this if you do: `jit(grad(f))`
```
def f_bwd(res, dz):
(cos_x: f32[4, 2]{R:x},) = res
dx_: f32[4, 2]{U: x} = dz: f32[4, 8@x] @ y.T: f32[8@x, 2]
dx: f32[4, 2]{U: x} = mul(dx_, cos_x)
return dx
```
As you can see `mul` gets 2 inputs `dx_` and `cos_x` (which is the residual). Since mul is bilinear, if one of the inputs is unreduced along a mesh axis, then the other input **has to be reduced** along the same mesh axis.
There's another thing to note: After grad, the type of `dx` will be `f32[4, 2]{U: x}` which makes sense because the type of `x` was `f32[4, 2]{R: x}`!
PiperOrigin-RevId: 829096082