[Mosaic] Support multiple non-contracting dims if they are collapsable.
With reshape, we can support the following two cases.
1. [batch_dims, non_contracting_dims, contracting_dims] -> [batch_dims, prod(non_contracting_dims), contracting_dims] or
2. [batch_dims, contracting_dims, non_contracting_dims] -> [batch_dims, contracting_dims, prod(non_contracting_dims)].
I'm reluctant to change apply vector layout and want to keep it to only handle 2D matrix.
PiperOrigin-RevId: 783816604