pytorch/pytorch commit 484dd400 - Implement PReLU in a compositional way (#91238)

The PReLU implementation was all over the place. This led to a number of bugs like https://github.com/pytorch/pytorch/issues/68760. We fix it by:
- Keeping the weird broadcasting logic it has as a CompositeImplicit kernel that calls into a second kernel.
- This second kernel is just a good-ol' pointwise kernel.
- We implement the derivative for the pointwise kernel via TensorIterator (TI) as well, for speed.
- We implement the second derivative for the pointwise kernel and the forward AD derivatives compositionally.

This fixes a number of issues:
- We don't perform copies any more when the inputs are not contiguous.
- The derivatives are now correct.
- We fix vmap and many other functorch-related issues.
- CPU and CUDA now share the relevant broadcasting logic.
- The implementation is about 1/3 the length.

Fixes https://github.com/pytorch/pytorch/issues/68760
Fixes https://github.com/pytorch/pytorch/issues/89895

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91238
Approved by: https://github.com/kshitij12345, https://github.com/jbschlosser, https://github.com/albanD
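To make the compositional split concrete, here is a rough Python-level sketch of the structure described above: the special per-channel broadcasting is handled up front, and the remaining work is a plain pointwise formula. This is an illustration only, not the actual ATen/C++ kernels from the PR; the helper name prelu_composite and the reshaping details are assumptions made for the example.

```python
import torch

def prelu_composite(input, weight):
    # Hypothetical sketch: reproduce PReLU's per-channel broadcasting by
    # viewing the weight so it lines up with dim 1 of the input, then
    # apply the pointwise rule: x if x >= 0 else w * x.
    if weight.numel() != 1 and input.dim() > 1:
        shape = [1] * input.dim()
        shape[1] = weight.numel()  # one weight per channel (dim 1)
        weight = weight.view(shape)
    return torch.where(input >= 0, input, weight * input)

# Quick check against the built-in functional op.
x = torch.randn(2, 3, 4, 5)
w = torch.randn(3)  # one weight per channel
assert torch.allclose(prelu_composite(x, w), torch.nn.functional.prelu(x, w))
```

Splitting the op this way means only the small broadcasting wrapper needs the special-case logic, while the pointwise part (and its derivatives) can be handled by the usual elementwise machinery on both CPU and CUDA.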