pytorch
f5d07808 - [PyTorch] MHA: fix dim_per_head / V bug (#72459)

Commit
3 years ago
[PyTorch] MHA: fix dim_per_head / V bug (#72459) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72459 This was pointed out in a comment on the original diff, but not fixed. ghstack-source-id: 149067331 Test Plan: cosine similarity with the existing MHA impl result on CPU + float32 goes from 0.2457 to 0.5097 Reviewed By: zrphercule Differential Revision: D33987869 fbshipit-source-id: b560ade85f577e83bcaf5b37da2e89d8646d5909 (cherry picked from commit 47511a2138a35b5e71ef3562a6e93cb59d965ab2)
Author
Committer
Parents
Loading