Removed overhead from reshape() call if tensor doesn't need to be changed (#61466)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61466
## Goal
Per #55126, `reshape` performs worse than `alias` in cases where they do the same work (i.e. where `reshape` returns a view), because `reshape` delegates to `view` and duplicates some of its operations (specifically `infer_size_dv` and `computeStride`).
The goal of this pull request is to reduce or remove that additional overhead in `reshape`.
### Proposed Implementation
Instead of delegating to `view`, `reshape` now dispatches to a private/internal operator (`_reshape_alias`) that skips the redundant checks. This is functionally equivalent to `as_strided`, but much simpler because it is specialized to this use case, and, importantly, its `backward` implementation is much faster.
Note that we still have to go through the dispatcher (`reshape` remains a composite operator) because `reshape` can return either a view or a copy of the tensor depending on its arguments, which complicates implementing a derivative/backward for `reshape` directly.
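For illustration, here is a minimal, simplified sketch of the intended dispatch structure (hedged; the exact code is in the diff, in ATen's `TensorShape.cpp` and `native_functions.yaml`):
```cpp
#include <ATen/ATen.h>
#include <ATen/InferSize.h>
#include <ATen/TensorUtils.h>

// Simplified sketch of reshape's fast path after this change.
at::Tensor reshape_sketch(const at::Tensor& self, at::IntArrayRef proposed_shape) {
  // Resolve any -1 in the requested shape and check whether that shape is
  // reachable as a view of the current strides.
  auto shape = at::infer_size_dv(proposed_shape, self.numel());
  auto stride = at::detail::computeStride(self.sizes(), self.strides(), shape);
  if (stride.has_value()) {
    // View case: dispatch to the private op, which simply aliases the storage
    // with the precomputed sizes/strides instead of re-running view's checks.
    return self._reshape_alias(shape, *stride);
  }
  // Copy case: materialize a contiguous copy and view it, as before.
  return at::_unsafe_view(self.clone(at::MemoryFormat::Contiguous), shape);
}
```
The `_reshape_alias` kernel itself boils down to `alias_with_sizes_and_strides` with the sizes and strides that `reshape` already computed, which is what shows up in the instruction counts below.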
### Why not `as_strided`?
Using `as_strided` directly slows down autograd. With a custom operator equivalent to `_reshape_alias` but with a simpler backward function, `reshape` matches the performance of `view`; delegating to `as_strided` instead makes the backward pass roughly 56% slower (relative both to `view` and to the custom operator).
This is also why we add an internal operator named `_reshape_alias` rather than exposing a new public operator: it should only be used by `reshape`, and it is effectively a more limited version of `view`, `alias`, and `as_strided`.
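To make the autograd point concrete: the backward of a view-style reshape only needs to restore the incoming gradient to the input's shape, with none of the stride bookkeeping that `as_strided_backward` performs. A hedged sketch (the function name and signature below are illustrative, not the generated backward node):
```cpp
#include <ATen/ATen.h>

// Illustrative only: the gradient of a view-style reshape is just the
// incoming gradient reshaped back to the input's original sizes, which is
// itself a cheap view in the common case.
at::Tensor reshape_alias_backward_sketch(const at::Tensor& grad,
                                         at::IntArrayRef input_sizes) {
  return grad.reshape(input_sizes);
}
```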
## Benchmarks
In a micro-benchmark of the `backward` path, running:
```cpp
// Setup
at::Tensor x = torch::empty({2, 2}, torch::requires_grad(true));
// Benchmark loop
// `reshape(-1)` replaced with a call to view(-1) for view baseline
x.pow(4).reshape(-1).mean().backward();
```
I also benchmarked simple operations without gradients using:
```cpp
// Setup
at::Tensor x = torch::empty({2, 2});
// Benchmark loop
x.reshape(-1); // replaced with a call to view(-1) for the view baseline
```
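The numbers below were collected with `torch.utils.benchmark` (wall-clock `Timer` measurements plus Callgrind instruction counts, as shown in the raw results further down). For a rough local sanity check without that harness, a plain wall-clock loop along these lines also works (the iteration count here is arbitrary):
```cpp
#include <chrono>
#include <cstdio>
#include <torch/torch.h>

int main() {
  at::Tensor x = torch::empty({2, 2});
  constexpr int kIters = 1'000'000;

  auto begin = std::chrono::steady_clock::now();
  for (int i = 0; i < kIters; ++i) {
    auto y = x.reshape({-1});  // swap for x.view({-1}) to get the view baseline
    (void)y;
  }
  auto end = std::chrono::steady_clock::now();

  double ns = std::chrono::duration<double, std::nano>(end - begin).count();
  std::printf("reshape(-1): %.1f ns/iter\n", ns / kIters);
  return 0;
}
```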
Baselined to `view`:
* Original `reshape`: `+3.3%` (without gradients `+20.8%`)
* Using `as_strided`: `+55.1%` (without gradients `+1.0%`)
* Using custom `_reshape_alias`: `-1.0%` (without gradients `+6.2%`)
In absolute terms (note that the percentages above were computed by comparing between runs/tests rather than against a single shared baseline):
* Original `view`: `53.66 us` (without gradients `582.78 ns`)
* Original `reshape`: `55.46 us` (without gradients `704.24 ns`)
* Using `as_strided`: `83.24 us` (without gradients `576.49 ns`)
* Using custom `_reshape_alias`: `53.13 us` (without gradients `536.01 ns`)
Note that the microsecond numbers above include a backward pass as well. In the no-gradient comparison the relative performance differences are more pronounced, since the reshape itself accounts for a larger share of the total time.
### Original performance
<details>
<summary>Benchmark results</summary>
```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f0e4d393160>
x.pow(4).view(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
Median: 53.66 us
IQR: 2.70 us (52.54 to 55.24)
884 measurements, 100 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f0e2ebd4fa0>
x.pow(4).reshape(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
Median: 55.46 us
IQR: 2.61 us (54.39 to 57.01)
889 measurements, 100 runs per measurement, 1 thread]
2276116
2286256
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f0e5b2e3e20>
2640 ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&)
1920 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
1040 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long>&&)
980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&)
720 ???:__tls_get_addr
520 ???:at::shouldRunRecordFunction(bool*)
520 ???:__memcpy_avx_unaligned_erms
200 ???:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10:: ... g>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
100 ???:c10::TensorImpl::strides() const
100 ???:c10::TensorImpl::sizes() const
100 ???:at::(anonymous namespace)::manager()
77 /tmp/benchmark_utils_jit_build__1626465284__8a34e7ff-cd37-4a82-be28-7f19e081e771/timer_cpp_7815557938202456331/timer_src.cpp:main
40 ???:c10::TensorImpl::numel() const
-77 /tmp/benchmark_utils_jit_build__1626465284__8a34e7ff-cd37-4a82-be28-7f19e081e771/timer_cpp_8055217880649990171/timer_src.cpp:main
-260 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
Total: 10140
```
```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f850dd66c10>
x.view(-1);
setup: at::Tensor x=torch::empty({2,2});
Median: 582.78 ns
IQR: 33.80 ns (573.80 to 607.61)
833 measurements, 10000 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f850de31e20>
x.reshape(-1);
setup: at::Tensor x=torch::empty({2,2});
Median: 704.24 ns
IQR: 24.42 ns (697.20 to 721.62)
679 measurements, 10000 runs per measurement, 1 thread]
56896
67036
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f84e1930bb0>
2640 ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&)
1920 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
1040 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long>&&)
980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&)
720 ???:__tls_get_addr
520 ???:at::shouldRunRecordFunction(bool*)
520 ???:__memcpy_avx_unaligned_erms
200 ???:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10:: ... g>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
100 ???:c10::TensorImpl::strides() const
100 ???:c10::TensorImpl::sizes() const
100 ???:at::(anonymous namespace)::manager()
76 /tmp/benchmark_utils_jit_build__1626466038__15fbbac0-2072-4459-8f8e-08121a905b99/timer_cpp_547407365342278353/timer_src.cpp:main
40 ???:c10::TensorImpl::numel() const
-76 /tmp/benchmark_utils_jit_build__1626466038__15fbbac0-2072-4459-8f8e-08121a905b99/timer_cpp_3457873755756181226/timer_src.cpp:main
-260 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
Total: 10140
```
</details>
### Using `as_strided`
<details>
<summary>Benchmark results</summary>
```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f8b13bb5b50>
x.pow(4).view(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
Median: 53.37 us
IQR: 3.15 us (51.73 to 54.88)
936 measurements, 100 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f8af55f8490>
x.pow(4).reshape(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
Median: 83.24 us
IQR: 4.05 us (81.20 to 85.25)
609 measurements, 100 runs per measurement, 1 thread]
2267916
2525061
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f8af55f8e50>
31930 ???:_int_free
15940 ???:malloc
11595 ???:_int_malloc
10100 ???:torch::autograd::generated::details::as_strided_backward(at::Tensor, at::TensorGeometry, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
9360 ???:__tls_get_addr
8280 ???:free
8100 ???:torch::autograd::VariableType::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
4520 ???:c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_()
4080 ???:operator new(unsigned long)
...
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&)
-1220 ???:torch::autograd::generated::ViewBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&)
-1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>)
-1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&)
-2560 ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&)
-2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
-4860 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
Total: 257145
```
```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f93176a0160>
x.view(-1);
setup: at::Tensor x=torch::empty({2,2});
Median: 570.55 ns
IQR: 32.69 ns (552.87 to 585.56)
874 measurements, 10000 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f92f8f29490>
x.reshape(-1);
setup: at::Tensor x=torch::empty({2,2});
Median: 576.49 ns
IQR: 37.95 ns (559.51 to 597.46)
861 measurements, 10000 runs per measurement, 1 thread]
56896
58556
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f932556ca60>
2140 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
1940 ???:torch::autograd::VariableType::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
1880 ???:torch::ADInplaceOrView::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
1720 ???:at::_ops::as_strided::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
1400 ???:at::native::as_strided_tensorimpl(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
1260 ???:at::_ops::as_strided::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)'2
1260 ???:at::_ops::as_strided::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&)
...
-620 ???:at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, c10::ArrayRef<long ... ::ArrayRef<long>)> const&, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) const
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&)
-1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>)
-1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&)
-1740 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
Total: 1660
```
</details>
### Using custom function (`_reshape_alias`)
<details>
<summary>Benchmark results</summary>
```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f16861d6b50>
x.pow(4).view(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
Median: 53.50 us
IQR: 2.64 us (52.32 to 54.96)
906 measurements, 100 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f1667b2ed60>
x.pow(4).reshape(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
Median: 53.13 us
IQR: 3.40 us (51.72 to 55.13)
914 measurements, 100 runs per measurement, 1 thread]
2269736
2273236
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f1693f8dc10>
5060 ???:torch::autograd::VariableType::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
2000 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
1780 ???:torch::ADInplaceOrView::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
1660 ???:at::_ops::_reshape_alias::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
1600 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::ArrayRef<long> >(at::Tensor const&, c10::ArrayRef<long> const&, c10::ArrayRef<long> const&)
1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)'2
1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
1220 ???:torch::autograd::generated::AliasToShapeBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&)
...
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&)
-1220 ???:torch::autograd::generated::ViewBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&)
-1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>)
-1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&)
-2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
-4860 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
Total: 3500
```
```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f5287adfb20>
x.view(-1);
setup: at::Tensor x=torch::empty({2,2});
Median: 505.10 ns
IQR: 20.04 ns (500.41 to 520.45)
944 measurements, 10000 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f526951b430>
x.reshape(-1);
setup: at::Tensor x=torch::empty({2,2});
Median: 536.01 ns
IQR: 17.81 ns (531.34 to 549.16)
916 measurements, 10000 runs per measurement, 1 thread]
56896
60376
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f5295896c10>
2000 ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
1860 ???:torch::autograd::VariableType::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
1780 ???:torch::ADInplaceOrView::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
1660 ???:at::_ops::_reshape_alias::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
1600 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::ArrayRef<long> >(at::Tensor const&, c10::ArrayRef<long> const&, c10::ArrayRef<long> const&)
1520 ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)'2
1240 ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
980 ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&)
...
-620 ???:at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, c10::ArrayRef<long ... ::ArrayRef<long>)> const&, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) const
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2
-780 ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-920 ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&)
-1520 ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>)
-1580 ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-1680 ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&)
-1740 ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
-2640 ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
Total: 3480
```
</details>
Test Plan: Imported from OSS
Reviewed By: ejguan
Differential Revision: D29792126
Pulled By: laurencer
fbshipit-source-id: f0519b45b65f868aa3e8651679354558bd761dfd