[MPS] Fix the crash in max_out() caused by cached key conflict (#91520)
The shape of input and indices tensors were missing in the cached key
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91520
Approved by: https://github.com/DenisVieriu97, https://github.com/kulinseth, https://github.com/malfet