[inductor] Recursively unwrap_storage_for_input when convert_to_reinterpret_view fails (#119867)
Summary:
When the underlying `ExternKernel.convert_to_reinterpret_view` fails during an `ExternKernel.realize_input` call, we currently fall back to `cls.copy_input` here:
https://github.com/pytorch/pytorch/blob/31e59766e7e7b51e8dddd4a6967891ac01f4d37b/torch/_inductor/ir.py#L3805-L3816
This creates a `TensorBox(StorageBox(...))` wrapped output, which causes a problem for this assertion:
https://github.com/pytorch/pytorch/blob/31e59766e7e7b51e8dddd4a6967891ac01f4d37b/torch/_inductor/ir.py#L3479
This PR adds special-case handling for this situation: when the result is still wrapped, `x` is unwrapped recursively.
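A minimal sketch of the recursive-unwrap idea, for illustration only. The `TensorBox`, `StorageBox`, and `Buffer` classes below are toy stand-ins and not the real classes in `torch/_inductor/ir.py`. The helper `unwrap_storage` is a hypothetical name, and the actual change in this PR may be structured differently:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Buffer:
    """Toy stand-in for a realized buffer (hypothetical, not the inductor class)."""
    name: str


@dataclass
class StorageBox:
    """Toy wrapper holding a single inner node."""
    data: Any


@dataclass
class TensorBox:
    """Toy wrapper holding a single inner node."""
    data: Any


def unwrap_storage(x: Any) -> Any:
    """Peel off TensorBox/StorageBox layers until a non-wrapper node remains."""
    if isinstance(x, (TensorBox, StorageBox)):
        # The inner node may itself be wrapped, so recurse instead of
        # assuming a single level of wrapping.
        return unwrap_storage(x.data)
    return x


# A copy-style fallback may hand back a doubly wrapped node.
wrapped = TensorBox(StorageBox(Buffer("buf0")))
unwrapped = unwrap_storage(wrapped)
```

In this sketch, a caller that asserts the result is a plain buffer would hold for `unwrapped`, whereas the same check on `wrapped` would fail.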
Test Plan:
This local repro:
```
import torch

@torch.compile()
def f(a, b, mat1, mat2):
    bias = torch.bmm(a + 3.14, b).permute(0, 2, 1).reshape(3992, -1)
    return torch.addmm(bias, mat1, mat2)

f(
    torch.randn(3992, 20, 40).cuda(),
    torch.randn(3992, 40, 192).cuda(),
    torch.empty(3992, 1024).cuda(),
    torch.empty(1024, 3840).cuda(),
)
```
with this line:
https://github.com/pytorch/pytorch/blob/690f54b0f5fa911ba9f7cb6f2ef9719ec765d2d2/torch/_inductor/fx_passes/post_grad.py#L650
changed to `if cond(*args, **kwargs):` fails before and succeeds after this PR.
Differential Revision: D53743146
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119867
Approved by: https://github.com/xw285cornell