pytorch
f324421d - [vulkan] Calculate a 4x4 output tile for each invocation in conv2d_pw (#60760)

Commit
3 years ago
[vulkan] Calculate a 4x4 output tile for each invocation in conv2d_pw (#60760) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60760 A simple optimization to the `conv2d_pw` shader that makes each invocation calculate a 4x4 output tile instead of a single output texel. This results in better memory reuse and subsequently a pretty significant performance win for models similar to the MobileNets. ## Perf improvements from this change On aloha portal devices, in conjunction with the above diff that introduces adaptive work group sizes, benchmark latency of the xirp14b model was reduced from ~8.7 ms to ~6.6 ms. Test Plan: Test vulkan ops: ``` cd ~/fbsource buck build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 //xplat/caffe2:pt_vulkan_api_test_binAndroid\#android-arm64 --show-output adb push buck-out/gen/xplat/caffe2/pt_vulkan_api_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_api_test adb shell "/data/local/tmp/vulkan_api_test" cd - ``` Reviewed By: IvanKobzarev Differential Revision: D28724590 fbshipit-source-id: e742286b01bf566dc6378677be55409b7faa8cfb
Author
Parents
Loading