[inductor] skip triton.Config that spills (#99385)
TL;DR: I did a quick study of register spills in max-autotune and coordinate descent tuning. The conclusion is that for pointwise/reduction kernels, register spilling is rare in inductor (which means the configs we consider are relatively reasonable), but it does happen sometimes. To be honest, this PR won't help reduce compilation time for max-autotune/coordinate descent tuning much, because register spilling is so rare. But it only contains 2 lines of significant code change, so I think it's still worth merging given the ROI and low code complexity.
# Register Spill in Max-Autotuner
I ran the command
```
rm -rf /tmp/torchinductor_shunting_tmp && time TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting_tmp python -u benchmarks/dynamo/torchbench.py --backend inductor --amp --performance --dashboard --only ${MODEL} --disable-cudagraphs --training 2>&1 | tee /tmp/mylog
```
and then analyzed the log.
```
$ cat /tmp/mylog | grep 'nspill' | wc -l
```
shows the total number of triton.Config's we benchmarked;
```
$ cat /tmp/mylog | grep 'nspill' | grep -v 'nspill 0'
```
lists the triton.Config's that spill registers.
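The same count can also be done in a few lines of Python; this is just a hypothetical convenience script mirroring the two greps above, not part of the PR:
```
# Hypothetical helper mirroring the two greps above: count how many of the
# benchmarked triton.Config's report a nonzero register spill count.
with open("/tmp/mylog") as f:
    bench_lines = [line for line in f if "nspill" in line]

spilled = [line for line in bench_lines if "nspill 0" not in line]
print(f"{len(spilled)}/{len(bench_lines)} benchmarked configs spill registers")
for line in spilled:
    print(line, end="")
```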
I checked 5 models:
- hf_Bert: 0 spills
- resnet50: 2 out of 199 triton.Config's spill. The 2 configs that spill are suboptimal according to the log: https://gist.github.com/shunting314/7ea30a9dafad7156919a99df5feba0ee
- timm_vision_transformer: 2/77 spill. The spilled configs are again suboptimal: https://gist.github.com/shunting314/a48cbcfb14a07c0b84555e2cf7154852
- BERT_pytorch: 0/123 spill
- timm_resnest: 0/255 spill
# Register Spill in Coordinate Descent Tuner
I ran the command
```
rm -rf /tmp/torchinductor_shunting_tmp && time TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting_tmp TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_PERSISTENT_REDUCTIONS=0 python -u benchmarks/dynamo/torchbench.py --backend inductor --amp --performance --dashboard --only ${MODEL} --disable-cudagraphs --training 2>&1 | tee /tmp/mylog
```
and then analyzed the log.
```
$ cat /tmp/mylog | grep COORDESC | wc -l
```
shows the total number of configs considered by the coordinate descent tuner;
```
$ cat /tmp/mylog | grep COORDESC | grep -v 'nspill 0'
```
shows the ones that spill.
I checked 3 models:
- hf_Bert: 0/525 spills (log: https://gist.github.com/shunting314/bd943887e77609c7c8b323fe3f554c85)
- resnet50: 0/783 spills
- timm_vision_transformer: 2/380 spill (log: https://gist.github.com/shunting314/6231f06c1398e0cddb2a96bf52389c78); the 2 spilled configs are again suboptimal
# Ignore Spilled Configs
With this PR, I reran the tests for timm_vision_transformer and can see that all 4 spilled configs (2 for max-autotune and 2 for the coordinate descent tuner, per the study above) are skipped during benchmarking:
```
[2023-04-18 00:03:37,291] torch._inductor.triton_heuristics: [DEBUG] Skip config XBLOCK: 16, YBLOCK: 512, num_warps: 8, num_stages: 1 because of register spilling: 6
[2023-04-18 00:04:50,523] torch._inductor.triton_heuristics: [DEBUG] Skip config XBLOCK: 64, RBLOCK: 64, num_warps: 8, num_stages: 1 because of register spilling: 626
[2023-04-18 00:04:50,523] torch._inductor.triton_heuristics: [DEBUG] Skip config XBLOCK: 8, RBLOCK: 512, num_warps: 8, num_stages: 1 because of register spilling: 778
[2023-04-18 00:05:47,170] torch._inductor.triton_heuristics: [DEBUG] Skip config XBLOCK: 1, num_warps: 2, num_stages: 1 because of register spilling: 4
```
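For reference, the skip itself is conceptually just a guard before benchmarking. Here is a minimal sketch of the idea, assuming the launcher for a compiled Triton kernel exposes the compiler-reported spill count as `n_spills` (the attribute and parameter names below are illustrative assumptions, not necessarily the exact identifiers in `triton_heuristics.py`):
```
import logging

log = logging.getLogger(__name__)

def bench_config(launcher, bench_fn, spill_threshold=0):
    # Sketch under assumptions: `launcher` wraps one compiled triton.Config
    # and exposes `config` plus the compiler-reported `n_spills`;
    # `bench_fn()` returns the measured runtime in ms.
    if launcher.n_spills > spill_threshold:
        # Report the config as infinitely slow instead of benchmarking it,
        # so the autotuner can never select it.
        log.debug(
            "Skip config %s because of register spilling: %d",
            launcher.config, launcher.n_spills,
        )
        return float("inf")
    return bench_fn()
```
Returning `float("inf")` keeps the rest of the autotuning loop unchanged: the spilled config simply loses every comparison, and we save the time we would have spent benchmarking it.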
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99385
Approved by: https://github.com/jansel