pytorch
f41db99a - Add simple correctness check for native MHA (#72941)

Commit
2 years ago
Add simple correctness check for native MHA (#72941) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72941 Simple test for MHA, use cos similarity as metric since scaling generate mismatch. Cuda is validated, CPU fix a following (We can land this with onlyCuda flag, and remove it once CPU is also done) Test Plan: For cuda: buck build mode/opt -c fbcode.enable_gpu_sections=true caffe2/test:nn && buck-out/gen/caffe2/test/nn\#binary.par -r test_native_multihead_attention_cuda_float32 2>&1 | pastry Reviewed By: swolchok Differential Revision: D33906921 fbshipit-source-id: ad447401eb7002f22ed533d620a6b544524b3f58 (cherry picked from commit 45b778da27598c1d4763aa22843b48a88fa90373)
Author
Committer
Parents
Loading