re-add option to log all guard check fails (#113585)
Summary:
Followup to https://github.com/pytorch/pytorch/pull/110325 - re-add the `report_all_guard_failures config` as a logging artifact `recompiles_verbose` with the following changes:
- evaluating the check must be wrapped with exception handling because subsequent code parts following the first failure may result in errors if evaluated (e.g. if a guard checks first for size, then tries to index - a guard failure due to insufficient size would result in an index error for the latter check).
- Adding a test for this case
Sample:
```python
import torch
def fn(x):
return torch.rand(x[-1], len(x))
opt_fn = torch.compile(fn)
opt_fn([4, 5, 6])
opt_fn([7, 8])
opt_fn([9])
```
Output (with `TORCH_LOGS="recompiles_verbose"`):
```bash
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] Recompiling function fn in /data/users/williamwen/pytorch/playground5.py:15
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] triggered by the following guard failure(s):
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] guard 0 failures:
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - len(L['x']) == 3
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - L['x'][0] == 4
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - L['x'][1] == 5
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] Recompiling function fn in /data/users/williamwen/pytorch/playground5.py:15
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] triggered by the following guard failure(s):
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] guard 0 failures:
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - len(L['x']) == 2
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG]
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] guard 1 failures:
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - len(L['x']) == 3
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - L['x'][0] == 4
```
X-link: https://github.com/pytorch/pytorch/pull/113585
Approved by: https://github.com/jon-chuang, https://github.com/ezyang
Reviewed By: huydhn
Differential Revision: D51417892
Pulled By: williamwen42
fbshipit-source-id: 2d99b70e3be500419e391067017f79d71ace6b9f