Refactor tensor_new.cpp to use TensorOptions instead of DispatchKey (#54034)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54034
Fixes #53544
I had to touch a bunch of lines but the refactoring was fairly
mechanical. Here's how it works.
The basic concept behind this PR is that tensor_new.cpp was previously
abusing DispatchKey when it actually meant TensorOptions. The provided
DispatchKey argument to most of the constructor functions typically
comes from torch::tensors::get_default_dispatch_key(); it doesn't
really make sense for people to set the default dispatch key, but
this got grandfathered in due to the old API set_default_tensor_type
(where the "Type" concept got refactored into "DispatchKey" concept
over time). See also #53124. But the upshot is that, semantically,
what we refer to as the default dispatch key really is more like
torch.set_default_tensor_type(torch.Tensor) versus
torch.set_default_tensor_type(torch.cuda.Tensor): clearly the user
wants to do something about *construction* of the tensor, and
TensorOptions captures that exactly.
So, how exactly to translate from one to the other?
- Sources (things that used to PRODUCE DispatchKey)
- Most top level functions take a DispatchKey as their argument. I
use the new function dispatchKeyToTensorOptions to convert it into
a TensorOptions
- typeIdWithDefault now produces a TensorOptions (probably could do
with a rename, though I didn't)
- Sinks (things that used to CONSUME DispatchKey)
- Previously, the function options() was typically used to convert the
DispatchKey into a TensorOptions. Now its replacement build_options
just takes a TensorOptions and sets some extra fields on it.
Irritatingly, I can't just replace
`build_options(options, scalar_type, device)` with
`options.dtype(scalar_type).device(device)` because the semantics
are slightly different: if device is nullopt, we should preserve
the usage of the device specified in options (what options.device()
does is overwrite the device unconditionally; e.g., if device is
nullopt, unset device from options)
- The other major sink for DispatchKey was `internal_new_from_data`,
but it turns out it only really extracts the device type from
the dispatch key. Now it just pulls out the device from
TensorOptions.
- To actually do the translation of DispatchKey to TensorOptions, I
introduce new functions dispatchKeyToLayout (replicating
layout_from_backend--there are still a few uses of this function
so I couldn't delete it) and dispatchKeyToDeviceType (replacing
computeDeviceType)
- In all internal functions, whenever DispatchKey is taken as an argument,
I instead take TensorOptions as an argument, and pass it along.
- Anywhere `legacyExtractDispatchKey(other.key_set())` equality was
previously used, I now do `other.options().type_equal()`, which
is the intended BC for doing "backend to backend" comparisons
- There are a few places in the sparse constructors where we allocated
a tensor for values, and then read out the dispatch key from the
result to allocate the keys. As best as I can tell, this is totally
equivalent to just passing in the options to both values and indices
(the only difference is dtype, which is captured via a separate
argument)
This refactor doesn't really go far enough: for example, there are now
functions that take both TensorOptions and ScalarType, when really
the TensorOptions can capture this all. I kept it solely just
s/DispatchKey/TensorOptions/ to reduce the number of possible bugs;
also, a lot of this will be mooted by a proper fix to #53124.
Even with this limited refactor, the payoff is sweet. I can delete:
- backendToCPU
- backendToXPU
- backendToCUDA
- backendToHIP
- backendToBackendOfDeviceType
The reason I can do this is because I can simply overwrite layout in TensorOptions
to do the conversion, rather than having to type out each backend case
explicitly.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: bhosmer
Differential Revision: D27109509
Pulled By: ezyang
fbshipit-source-id: 91d16cfbc390127770362ac04fb43f7e070077e9