pytorch
24df3b73 - torch.empty_like and torch.zeros_like raise error if any memory format is provided with sparse input (#43699) (#44058)

Commit
4 years ago
torch.empty_like and torch.zeros_like raise error if any memory format is provided with sparse input (#43699) (#44058) Summary: Fixes https://github.com/pytorch/pytorch/issues/43699 - Changed the order of `TORCH_CHECK` and `if (options.layout() == kSparse && self.is_sparse())` inside `empty_like` method. - [x] Added tests EDIT: More details on that and why we can not take zeros_like approach. Python code : ```python res = torch.zeros_like(input_coalesced, memory_format=torch.preserve_format) ``` is routed to ```c++ // TensorFactories.cpp Tensor zeros_like( const Tensor& self, const TensorOptions& options, c10::optional<c10::MemoryFormat> optional_memory_format) { if (options.layout() == kSparse && self.is_sparse()) { auto res = at::empty({0}, options); // to be resized res.sparse_resize_and_clear_( self.sizes(), self.sparse_dim(), self.dense_dim()); return res; } auto result = at::empty_like(self, options, optional_memory_format); return result.zero_(); } ``` and passed to `if (options.layout() == kSparse && self.is_sparse())` When we call in Python ```python res = torch.empty_like(input_coalesced, memory_format=torch.preserve_format) ``` it is routed to ```c++ Tensor empty_like( const Tensor& self, const TensorOptions& options_, c10::optional<c10::MemoryFormat> optional_memory_format) { TORCH_CHECK( !(options_.has_memory_format() && optional_memory_format.has_value()), "Cannot set memory_format both in TensorOptions and explicit argument; please delete " "the redundant setter."); TensorOptions options = self.options() .merge_in(options_) .merge_in(TensorOptions().memory_format(optional_memory_format)); TORCH_CHECK( !(options.layout() != kStrided && optional_memory_format.has_value()), "memory format option is only supported by strided tensors"); if (options.layout() == kSparse && self.is_sparse()) { auto result = at::empty({0}, options); // to be resized result.sparse_resize_and_clear_( self.sizes(), self.sparse_dim(), self.dense_dim()); return result; } ``` cc pearu Pull Request resolved: https://github.com/pytorch/pytorch/pull/44058 Reviewed By: albanD Differential Revision: D23672494 Pulled By: mruberry fbshipit-source-id: af232274dd2b516dd6e875fc986e3090fa285658
Author
Parents
Loading