2896f81d - Consolidate customization contiguous/sizes policy into unified policy

Prior to this PR, we had a mish-mash of ways of getting unconventional sizes/strides behavior:

- In OSS (but not in fbcode), some methods are virtual and you can override them directly.
- There is an is_contiguous policy, a bitfield tag that lets you make is_contiguous() error out or hit a virtual method is_contiguous_custom() when the tag is set. Ordinarily is_contiguous() is virtual and you can just override it, but this works EVEN IF is_contiguous() is non-virtual (e.g., in fbcode).
- There is also a sizes policy, which is the same idea but for sizes.

This PR unifies these mechanisms and, in doing so, eliminates the maybe-virtual/non-virtualness of the methods in question. The primary downside of this change is that it is BC-breaking (but the BC break is very easy to fix!)

The new scheme works like this: we have three levels of policy for sizes/strides (order matters).

- The Default policy is a conventional dense tensor, where we use all of the built-in fields to directly represent the sizes/strides/numel/contiguity of the tensor, and it is possible to bypass virtual calls entirely.
- The CustomStrides policy represents tensors which have a custom notion of strides (most typically, that they don't support them), shunting strides() and is_contiguous() to the virtual methods strides_custom() and is_contiguous_custom(). This INCLUDES handling for contiguity, since strides and contiguity typically go hand-in-hand (although the situation is murky with batched tensors). The default implementations of these functions raise errors saying the tensor doesn't support them.
- The CustomSizes policy represents tensors which have a custom notion of sizes (the two notable examples are nested tensor, which doesn't represent sizes in the conventional form, and XLA/LTC tensor, which synchronizes its sizes with an underlying compiler backend). This shunts sizes(), numel() and dim() (along with everything from strides) to their _custom() variants. There is no special policy for erroring; instead, we just do a vcall and expect the virtual method to raise an exception (the performance hit from the vcall doesn't matter because you're about to raise a C++ exception anyway).

The default implementations of all overridable functions are available as _default() variants, which is helpful in situations where you just want to do a "sync" and then run the conventional semantics.

This PR could be extended further in two ways, but I did not do them due to time constraints:

- Ideally, all TENSORIMPL_MAYBE_VIRTUAL would be eliminated from TensorImpl by using the same policy trick.
- set_size and set_stride are still virtual; it's not entirely clear the same trick should be used here, though, as these methods are deprecated.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77036
Approved by: https://github.com/bdhirsh
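To make the mechanism concrete, here is a minimal C++ sketch of how a subclass would opt into the CustomStrides level under the scheme described above. MyOpaqueTensorImpl is a hypothetical name, and the identifiers SizesStridesPolicy, set_sizes_strides_policy(), strides_custom() and is_contiguous_custom() are inferred from the description; check c10/core/TensorImpl.h at this revision for the exact spellings and signatures.

```cpp
// Hypothetical sketch (not code from this PR): a TensorImpl subclass that has
// no meaningful strides, opting into the CustomStrides level. Names taken from
// the commit description above; verify against c10/core/TensorImpl.h.
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>

struct MyOpaqueTensorImpl : public c10::TensorImpl {
  MyOpaqueTensorImpl(
      c10::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      c10::optional<c10::Device> device)
      : TensorImpl(key_set, data_type, device) {
    // Shunt strides()/is_contiguous() to the *_custom() virtuals below;
    // sizes(), numel() and dim() keep the non-virtual Default fast path.
    set_sizes_strides_policy(SizesStridesPolicy::CustomStrides);
  }

  // Reached instead of the inline strides() fast path; this backend simply
  // refuses to answer, mirroring the default error-raising behavior.
  c10::IntArrayRef strides_custom() const override {
    TORCH_CHECK(false, "MyOpaqueTensorImpl does not have strides");
  }

  // Contiguity travels together with strides under CustomStrides.
  bool is_contiguous_custom(c10::MemoryFormat memory_format) const override {
    return memory_format == c10::MemoryFormat::Contiguous;
  }
};
```

Under CustomSizes, sizes(), numel() and dim() would additionally be routed to their _custom() variants, which is how nested tensors and XLA/LTC tensors plug in; the _default() variants remain available when an override only needs to "sync" state and then fall back to the conventional behavior.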
Files changed:

  • aten/src/ATen
    • BatchedTensorImpl.cpp
    • BatchedTensorImpl.h
    • NestedTensorImpl.cpp
    • NestedTensorImpl.h
    • OpaqueTensorImpl.h
    • SparseCsrTensorImpl.cpp
    • SparseCsrTensorImpl.h
    • SparseTensorImpl.cpp
    • SparseTensorImpl.h
    • native/vulkan
      • VulkanOpaqueTensorImpl.h
  • c10/core
    • TensorImpl.cpp
    • TensorImpl.h
    • UndefinedTensorImpl.cpp
    • UndefinedTensorImpl.h
  • torch/csrc/lazy/core
    • tensor_impl.cpp
    • tensor_impl.h