pytorch
dc2e6303 - Optimize PReLU (float32) and enable PReLU BFloat16 support in CPU path (#63634)

Commit
2 years ago
Optimize PReLU (float32) and enable PReLU BFloat16 support in CPU path (#63634) Summary: In this PR, we try to optimize PReLU op in CPU path, and enable BFloat16 support based on the optimized PReLU. The original implementation uses parallel_for to accelerate operation speed, but vectorization is not used. It can be optimized by using TensorIterator, both including parallelization and vectorization. The difference between PReLU and other activation function ops, is that PReLU supports a learnable parameter `weight`. When called without arguments, nn.PReLU() uses a single parameter `weight` across all input channels. If called with nn.PReLU(nChannels), a separate `weight` is used for each input channel. So we cannot simply use TensorIterator because `weight` is different for each input channel. In order to use TensorIterator, `weight` should be broadcasted to `input` shape. And with vectorization and parallel_for, this implementation is much faster than the original one. Another advantage is, don't need to separate `share weights` and `multiple weights` in implementation. We test the performance between the PReLU implementation of public Pytorch and the optimized PReLU in this PR, including fp32/bf16, forward/backward, share weights/multiple weights configurations. bf16 in public Pytorch directly reuses `Vectorized<scalar_t>` for `BFloat16`. Share weights: ![image](https://user-images.githubusercontent.com/61222868/130403002-ef271bee-0cae-460b-b796-46853599c210.png) ![image](https://user-images.githubusercontent.com/61222868/130403028-96753102-bea3-44c2-8656-2526469e0627.png) Multiple weights: ![image](https://user-images.githubusercontent.com/61222868/130403059-a3418eb2-9546-471f-b057-15bc0e46f0d0.png) ![image](https://user-images.githubusercontent.com/61222868/130403070-8c620db9-f354-4ddd-b5d5-4557e10ea77a.png) cc albanD mruberry jbschlosser walterddr Pull Request resolved: https://github.com/pytorch/pytorch/pull/63634 Reviewed By: yinghai Differential Revision: D34031616 Pulled By: frank-wei fbshipit-source-id: 04e2a0f9e92c658fba7ff56b1010eacb7e8ab44c (cherry picked from commit ed262b15487557720bb0d498f9f2e8fcdba772d9)
Author
Committer
Parents
Loading