set num_warps to at least 4 (#97950)
To avoid IMAs in https://gist.github.com/ngimel/25e81c996d9c8c652d97e33cc9c7d5f4
This is not a general fix (e.g. if inputs were a bit larger, num_warps would naturally be 4, and we could still have spills and hit ptxas bugs), but will do for now.
Longer term, we should check spills in kernels we generate and recompile with more warps if there are spills.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97950
Approved by: https://github.com/bertmaher
Author
Natalia Gimelshein