masked_scatter: fuse mask count check into one kernel (#66871)
Summary:
Fusing the mask count check into a single kernel saves 1 kernel launch, 7 dispatcher calls, 3 `TensorImpl` allocations, and 1 CUDA memory allocation.
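
For illustration only, here is a pure-Python sketch of the general idea of fusing a count check into the scatter itself. It is not the code from this PR and does not use PyTorch. The function names `masked_scatter_two_pass` and `masked_scatter_fused`, the list-based inputs, and the error message are all hypothetical. The two-pass version counts the set mask elements in a separate pass before scattering; the fused version checks the source bound inside the same loop that does the scatter.

```python
def masked_scatter_two_pass(dest, mask, source):
    """Hypothetical unfused version: count the mask first, then scatter."""
    # Separate pass over the mask (on a GPU this would be its own reduction).
    needed = sum(1 for m in mask if m)
    if needed > len(source):
        raise ValueError("source has fewer elements than mask has set entries")
    out = list(dest)
    src_idx = 0
    for i, m in enumerate(mask):
        if m:
            out[i] = source[src_idx]
            src_idx += 1
    return out


def masked_scatter_fused(dest, mask, source):
    """Hypothetical fused version: the bound check lives in the scatter loop."""
    out = list(dest)
    src_idx = 0
    for i, m in enumerate(mask):
        if m:
            # The check that used to need its own pass happens here instead.
            if src_idx >= len(source):
                raise ValueError("source has fewer elements than mask has set entries")
            out[i] = source[src_idx]
            src_idx += 1
    return out
```

Both functions return the same result for valid inputs and both reject a source that is too short; the fused one simply avoids the extra pass over the mask.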
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66871
Reviewed By: gchanan
Differential Revision: D31763713
Pulled By: ngimel
fbshipit-source-id: b0d2f9415b7fd013fb4e7d68ade6e38a58f5b153