Fix dispatching of backwards kernel for ROCm. (#22125)
Summary:
Use WARP_SIZE consistently also for the dispatch dimensions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22125
Differential Revision: D15966661
Pulled By: bddppq
fbshipit-source-id: 93eb663e01aff3b49474504a2f96f060919edf0c