pytorch
690c2a70 - masked_scatter: fuse mask count check into one kernel (#66871)

Commit
3 years ago
masked_scatter: fuse mask count check into one kernel (#66871)

Summary:
This saves 1 kernel launch, 7 dispatcher calls, 3 `TensorImpl` allocations and 1 CUDA memory allocation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66871

Reviewed By: gchanan

Differential Revision: D31763713

Pulled By: ngimel

fbshipit-source-id: b0d2f9415b7fd013fb4e7d68ade6e38a58f5b153
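
To illustrate the idea behind the summary, here is a minimal pure-Python sketch of what "fusing the mask count check" could look like. It is not the code from this commit and does not use PyTorch or CUDA. The function names `masked_scatter_two_pass` and `masked_scatter_fused` are hypothetical, and the sketch assumes flat lists in place of tensors. The first version counts the `True` entries of the mask in a separate pass before scattering, which loosely corresponds to the extra kernel launch and temporary allocation the summary mentions. The second version performs the same bounds check inside the scatter loop itself.

```python
def masked_scatter_two_pass(dest, mask, source):
    """Unfused sketch: count the mask in one pass, then scatter in another."""
    # Separate counting pass (stands in for a standalone reduction).
    needed = sum(1 for m in mask if m)
    if needed > len(source):
        raise ValueError("source has fewer elements than mask has True values")
    out = list(dest)
    src_idx = 0
    for i, m in enumerate(mask):
        if m:
            out[i] = source[src_idx]
            src_idx += 1
    return out


def masked_scatter_fused(dest, mask, source):
    """Fused sketch: the bounds check happens inside the scatter loop."""
    out = list(dest)
    src_idx = 0
    for i, m in enumerate(mask):
        if m:
            # The check rides along with the work that was already needed,
            # so no separate counting pass is required.
            if src_idx >= len(source):
                raise ValueError(
                    "source has fewer elements than mask has True values"
                )
            out[i] = source[src_idx]
            src_idx += 1
    return out
```

Both functions return the same result for valid inputs and raise for a source that is too short. The fused version only detects the error once it reaches the offending element, which is the usual trade-off of moving a precondition check into the main loop.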