pytorch
591b75bf - Redo how custom/python_custom methods on TensorImpl work (#84641)

Commit
2 years ago
Redo how custom/python_custom methods on TensorImpl work (#84641) A longstanding confusion in the implementation of fake tensor and proxy tensor is what to do about torch.ops.aten.sym_sizes and related calls. In particular, when you have a tensor that (1) has symbolic shapes and (2) has a `__torch_dispatch__` call, previously, you would always get `__torch_dispatch__` calls for sizes/strides query, *even if you didn't request it* via the dispatch kwargs in `make_wrapper_subclass`. The reason for this is because we were previously mixing several concepts: "I want to dispatch to Python", "I want to call a virtual method" and "I have dynamic shapes". A single boolean variable controlled all of these things, and so it was not possible to understand inside TensorImpl what the user had actually originally requested. In this PR, we track each of these concepts individually so that we can preserve user intent. Then, we combine these into a single "policy" variable that controls whether or not we can use the fastpath or not. For the policy to trigger, we only need one of the exceptional cases to be true. Billing of changes: * Rename `set_sizes_strides_policy` to `set_custom_sizes_strides`; in general, you cannot DIRECTLY set policy; you have to indirectly set it by the public functions. * Some helpers for sizes and strides, since it's more complicated (as it is an enum, rather than just bools as is the case for device and layout). `matches_python_custom` is used to test the Python dispatch user ask. `matches_policy` does the policy test (only used in the user facing functions.) * I reorged the accessor methods so that they are more logical. This makes the diff bad, so I recommend reading the final code directly. * The default custom implementations now more reliably call their default() implementations * As bonus refactor, I devirtualized some functions that don't need to be virtual * `set_sym_sizes_and_strides` is renamed to `set_sizes_and_strides` to make it easier to use in template contexts; it optionally takes a storage offset now so you can set all three values at the same time. If you use the SymInt overload but there are no symbolic integers, we give you a normal resize. * This adds `sym_storage_offset` since we had that in the symbolic shapes branch and there's no reason not to put it in (and it reduces merge conflicts) Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/84641 Approved by: https://github.com/wconstab
Author
Committer
Parents
Loading