fix cosine_similarity (#18168)
Summary:
fixes #18057 according to colesbury 's suggestion. Thanks!
cc: ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18168
Differential Revision: D14520953
Pulled By: ailzhang
fbshipit-source-id: 970e6cfb482d857a81721ec1d0ee4a4df84a0450