Fix the shape inconsistency of `out` and `elem` tensor (#71065)
Summary:
See bug report https://github.com/pytorch/pytorch/issues/71063
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71065
Reviewed By: anjali411
Differential Revision: D33549921
Pulled By: ejguan
fbshipit-source-id: bc43f5f9a88f7dcd8729d0e0f4b90d20f40b3064