pytorch
c660db20 - Adding vmap support for special bessel functions (#99543)

Commit
1 year ago
Adding vmap support for special bessel functions (#99543) Fixes #91402 ## Description This PR adds vmap support for the following bessel functions under torch.special. * special.bessel_j0 * special.bessel_y0 * special.bessel_j1 * special.modified_bessel_i0 * special.bessel_y1 * special.scaled_modified_bessel_k0 * special.scaled_modified_bessel_k1 * special.modified_bessel_i1 ## Files changed: 1. [aten/src/ATen/functorch/BatchRulesUnaryOps.cpp](https://github.com/pytorch/pytorch/pull/99543/files#diff-a629acd680b2c8639049755617fe89f803cd1001d9936e95d7bf4e388f41c6b8) 2. [test/functorch/test_vmap.py](https://github.com/pytorch/pytorch/compare/main...SiddharthIVEX:pytorch:sid/vmap_special_bessel?expand=1#diff-17b0cd027c7b1ca042fcfe21cc86284d6e58fa46039f7e4297b22b8e02b68fea) ## How was the PR tested? 1. The unit tests under `test_vmap.py` were run and all of them passed. The output is shown below. ``` configfile: pytest.ini plugins: hypothesis-6.71.0, anyio-2.2.0 collected 2003 items / 1981 deselected / 22 selected test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_op_has_batch_rule_special_bessel_j0_cpu_float32 PASSED [ 4%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_op_has_batch_rule_special_bessel_j1_cpu_float32 PASSED [ 9%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_op_has_batch_rule_special_bessel_y0_cpu_float32 PASSED [ 13%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_op_has_batch_rule_special_bessel_y1_cpu_float32 PASSED [ 18%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_op_has_batch_rule_special_modified_bessel_i0_cpu_float32 PASSED [ 22%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_op_has_batch_rule_special_modified_bessel_i1_cpu_float32 PASSED [ 27%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_op_has_batch_rule_special_modified_bessel_k0_cpu_float32 PASSED [ 31%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_op_has_batch_rule_special_modified_bessel_k1_cpu_float32 PASSED [ 36%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_op_has_batch_rule_special_scaled_modified_bessel_k0_cpu_float32 PASSED [ 40%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_op_has_batch_rule_special_scaled_modified_bessel_k1_cpu_float32 PASSED [ 45%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_op_has_batch_rule_special_spherical_bessel_j0_cpu_float32 PASSED [ 50%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_vmap_exhaustive_special_bessel_j0_cpu_float32 PASSED [ 54%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_vmap_exhaustive_special_bessel_j1_cpu_float32 PASSED [ 59%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_vmap_exhaustive_special_bessel_y0_cpu_float32 PASSED [ 63%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_vmap_exhaustive_special_bessel_y1_cpu_float32 PASSED [ 68%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_vmap_exhaustive_special_modified_bessel_i0_cpu_float32 PASSED [ 72%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_vmap_exhaustive_special_modified_bessel_i1_cpu_float32 PASSED [ 77%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_vmap_exhaustive_special_modified_bessel_k0_cpu_float32 PASSED [ 81%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_vmap_exhaustive_special_modified_bessel_k1_cpu_float32 PASSED [ 86%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_vmap_exhaustive_special_scaled_modified_bessel_k0_cpu_float32 PASSED [ 90%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_vmap_exhaustive_special_scaled_modified_bessel_k1_cpu_float32 PASSED [ 95%] test/functorch/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_vmap_exhaustive_special_spherical_bessel_j0_cpu_float32 PASSED [100%] ================================================================ 22 passed, 1981 deselected in 18.42s ================================================================ ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/99543 Approved by: https://github.com/zou3519
Author
Committer
Parents
Loading