pytorch
00d432a1 - Remove optional for veiw_fn during View Tracking (#50067)

Commit
3 years ago
Remove optional for veiw_fn during View Tracking (#50067) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50067 Fixes #49257 Using the `Callgrind` to test the performance. ```python import torch import timeit from torch.utils.benchmark import Timer timer = Timer("x.view({100, 5, 20});", setup="torch::Tensor x = torch::ones({10, 10, 100});", language="c++", timer=timeit.default_timer) res = timer.collect_callgrind(number=10) ``` ### Nightly ```python torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7f7949138c40> x.view({100, 5, 20}); setup: torch::Tensor x = torch::ones({10, 10, 100}); All Noisy symbols removed Instructions: 42310 42310 Baseline: 0 0 10 runs per measurement, 1 thread Warning: PyTorch was not built with debug symbols. Source information may be limited. Rebuild with REL_WITH_DEB_INFO=1 for more detailed results. ``` ### Current ```python <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7f78f271a580> x.view({100, 5, 20}); setup: torch::Tensor x = torch::ones({10, 10, 100}); All Noisy symbols removed Instructions: 42480 42480 Baseline: 0 0 10 runs per measurement, 1 thread Warning: PyTorch was not built with debug symbols. Source information may be limited. Rebuild with REL_WITH_DEB_INFO=1 for more detailed results. ``` ### Compare There are 170 instructions reduced ```python torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f7941b7a7c0> 970 ???:torch::autograd::as_view(at::Tensor const&, at::Tensor const&, bool, bool, std::function<at::Tensor (at::Tensor const&)>, torch::autograd::CreationMeta, bool) 240 ???:torch::autograd::ViewInfo::~ViewInfo() 180 ???:torch::autograd::ViewInfo::ViewInfo(at::Tensor, std::function<at::Tensor (at::Tensor const&)>) 130 ???:torch::autograd::make_variable_differentiable_view(at::Tensor const&, c10::optional<torch::autograd::ViewInfo>, c10::optional<torch::autograd::ViewInfo>, torch::autograd::CreationMeta, bool) 105 /tmp/benchmark_utils_jit_build_69e2f1710544485588feeca0719a3a57/timer_cpp_4435526292782672407/timer_src.cpp:main 100 ???:std::function<at::Tensor (at::Tensor const&)>::function(std::function<at::Tensor (at::Tensor const&)> const&) 70 ???:torch::autograd::DifferentiableViewMeta::~DifferentiableViewMeta() 70 ???:torch::autograd::DifferentiableViewMeta::DifferentiableViewMeta(c10::TensorImpl*, c10::optional<torch::autograd::ViewInfo>, c10::optional<torch::autograd::ViewInfo>, torch::autograd::CreationMeta) -100 ???:c10::optional_base<torch::autograd::ViewInfo>::optional_base(c10::optional_base<torch::autograd::ViewInfo>&&) -105 /tmp/benchmark_utils_jit_build_2e75f38b553e42eba00523a86ad9aa05/timer_cpp_3360771523810516633/timer_src.cpp:main -120 ???:torch::autograd::ViewInfo::ViewInfo(at::Tensor, c10::optional<std::function<at::Tensor (at::Tensor const&)> >) -210 ???:c10::optional_base<std::function<at::Tensor (at::Tensor const&)> >::~optional_base() -240 ???:c10::optional_base<torch::autograd::ViewInfo>::~optional_base() -920 ???:torch::autograd::as_view(at::Tensor const&, at::Tensor const&, bool, bool, c10::optional<std::function<at::Tensor (at::Tensor const&)> >, torch::autograd::CreationMeta, bool) ``` Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D25900495 Pulled By: ejguan fbshipit-source-id: dedd30e69db6b48601a18ae98d6b28faeae30d90
Author
Parents
Loading