[PyTorch] Fix MHA grain size computation (#72463)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72463
maxing with 1 makes a lot more sense to me than minning with 1, but I have no idea what I'm doing.
ghstack-source-id: 149067332
Test Plan: CI
Reviewed By: zrphercule
Differential Revision: D33990633
fbshipit-source-id: c706148c357473c929020f5dc65cc5050611af8f
(cherry picked from commit 2adf3be11a59387bbab7fc73da236ab5fff7be9c)