Update ScatterWeightedSum Op (#23087)
Summary:
Update the ScatterWeightedSum op for the case where only one weighted X is used to update slices of Y, which is the usual case when the op is used for gradient updates. The change removes the copy overhead and significantly improves operator performance:
- 25-50% improvement on CUDA, depending on input configuration
- ~50% improvement on ROCm
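
For reference, a minimal sketch of the single-X case in plain Python. It is illustrative only and is not the CUDA/HIP kernel from this change. The function and argument names are hypothetical, and the update rule (Y[idx] = w_y * Y[idx] + w_x * X[i], applied in place) is assumed from the op's documented behavior:

```python
def scatter_weighted_sum_single(y, indices, x, w_y=1.0, w_x=1.0):
    """Hypothetical reference for the one-X case.

    y       : list of rows, updated in place (the Y being sliced)
    indices : row indices of y to update
    x       : list of rows, one per entry in indices (the single weighted X)
    """
    for i, idx in enumerate(indices):
        # Only the indexed rows of y are touched; all other rows stay as-is.
        y[idx] = [w_y * a + w_x * b for a, b in zip(y[idx], x[i])]
    return y


# Gradient-update style usage: parameters minus a scaled sparse gradient.
params = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
grad = [[0.5, 0.5], [1.0, 1.0]]
scatter_weighted_sum_single(params, [0, 2], grad, w_x=-0.1)
```

Because the rows of Y are written directly, no intermediate copy of X or Y is needed in this case.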
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23087
Differential Revision: D16385194
Pulled By: bddppq
fbshipit-source-id: 3189e892940fb9c26305269eb0d47479b9b71af0