jax
aa9e476a - Fix broadcast_in_dim tranpose rules if input is reduced, the ct should be unreduced.

Commit
25 days ago
Fix broadcast_in_dim tranpose rules if input is reduced, the ct should be unreduced. For example: If `f32[8]{R:x} ---broadcast---> f32[2@x, 8]`, then on transpose, we should do a lazy psum such that the cotangent type is `f32[8]{U:x}` PiperOrigin-RevId: 866090525
Author
Parents
Loading