[NNC] Fix a bug in SplitWithMask when splitting multiple times (#45141)
Summary:
When doing a splitWithMask we only mask if the loop extent is not cleanly divide by the split factor. However, the logic does not simplify so any nontrivial loop extents will always cause a mask to be added, e.g. if the loop had been previously split. Unlike splitWithTail, the masks added by splitWithMask are always overhead and we don't have the analysis to optimize them out if they are unnecessary, so it's good to avoid inserting them if we can.
The fix is just to simplify the loop extents before doing the extent calculation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45141
Reviewed By: ezyang
Differential Revision: D23869170
Pulled By: nickgg
fbshipit-source-id: 44686fd7b802965ca4f5097b0172a41cf837a1f5