Refactor TensorIterator to do allocations via MetaBase::set_output (#48659)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48659
Detailed RFC at
https://github.com/pytorch/rfcs/blob/rfc-0005/RFC-0005-structured-kernel-definitions.md#handling-tensoriterator
What this diff does:
* Refactor allocation of outputs in TensorIterator into a call to a single function TensorIterator::set_output. This nicely centralizes restriding logic and mostly eliminates the need for a separate named tensor propagation pass. The one exception is for inplace operations (`add_`), where previously we never actually call `set_output` when we determine resizing is not necessary; there's an extra propagate names in `allocate_or_resize_outputs` to handle this case (I audited all other `set_output` sites and found that we always hit this path in that situation). Although hypothetically this could cause problems for structured kernels (which require a `set_output` call in all cases), this codepath is irrelevant for structured kernels as a TensorIterator will never be constructed with an explicit out argument (remember, structured kernels handle out/functional/inplace variants). There's also a tricky case in `compute_types`; check the comments there for more details.
* Split TensorIterator into a TensorIteratorBase, which contains most of the logic but doesn't define `set_output`. A decent chunk of the diff is just the mechanical rename of TensorIterator to TensorIteratorBase. However, there are a few cases where we create fresh TensorIterator objects from another TensorIterator. In those cases, we always construct a fresh TensorIterator (rather than preserving the subclass of TensorIteratorBase that induced this construction). This makes sense, because a structured function class will contain metadata that isn't relevant for these downstream uses. This is done by *intentionally* permitting object slicing with the `TensorIterator(const TensorIteratorBase&)` constructor.
* Introduce a new `MetaBase` class which contains the canonical virtual method definition for `set_output`. This will allow structured classes to make use of it directly without going through TensorIterator (not in this PR).
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: bhosmer
Differential Revision: D25261844
Pulled By: ezyang
fbshipit-source-id: 34a9830cccbc07eaaf7c4f75114cd00953e3db7d