[Dynamo] Add correct guards for tracable tensor subclasses (#119110)
Fixes #118896
```
(pt) [ybliang@devgpu002.ash8 ~/local/pytorch (subclass)]$ TORCH_LOGS="+guards" python test/dynamo/test_subclasses.py -k test_torch_dispatch_subclass_guard_recompile
/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
[2024-02-02 16:43:02,186] [0/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS:
[2024-02-02 16:43:02,186] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['w'], 110557008) # return torch.add(w, 1.0) # ata/users/ybliang/pytorch/test/dynamo/test_subclasses.py:923 in fn
[2024-02-02 16:43:02,187] [0/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['w'].a, '_dynamo_dynamic_indices') == False # return torch.add(w, 1.0) # ata/users/ybliang/pytorch/test/dynamo/test_subclasses.py:923 in fn
[2024-02-02 16:43:02,187] [0/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['w'].b, '_dynamo_dynamic_indices') == False # return torch.add(w, 1.0) # ata/users/ybliang/pytorch/test/dynamo/test_subclasses.py:923 in fn
[2024-02-02 16:43:02,187] [0/0] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:388 in init_ambient_guards
[2024-02-02 16:43:02,187] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_current_backend(139704947520224) # _dynamo/output_graph.py:394 in init_ambient_guards
[2024-02-02 16:43:02,187] [0/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['w'].a, Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[2, 2], stride=[2, 1]) # return torch.add(w, 1.0) # ata/users/ybliang/pytorch/test/dynamo/test_subclasses.py:923 in fn
[2024-02-02 16:43:02,187] [0/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['w'].b, Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[2, 2], stride=[2, 1]) # return torch.add(w, 1.0) # ata/users/ybliang/pytorch/test/dynamo/test_subclasses.py:923 in fn
[2024-02-02 16:43:02,206] [0/1] torch._dynamo.guards.__guards: [DEBUG] GUARDS:
[2024-02-02 16:43:02,207] [0/1] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['w'], '_dynamo_dynamic_indices') == False # return torch.add(w, 1.0) # ata/users/ybliang/pytorch/test/dynamo/test_subclasses.py:923 in fn
[2024-02-02 16:43:02,207] [0/1] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:388 in init_ambient_guards
[2024-02-02 16:43:02,207] [0/1] torch._dynamo.guards.__guards: [DEBUG] ___check_current_backend(139704947520224) # _dynamo/output_graph.py:394 in init_ambient_guards
[2024-02-02 16:43:02,207] [0/1] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['w'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[2, 2], stride=[2, 1]) # return torch.add(w, 1.0) # ata/users/ybliang/pytorch/test/dynamo/test_subclasses.py:923 in fn
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119110
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh, https://github.com/yoyoyocmu