74d81080 - Use new_zeros in evenly_distribute_backward (#46674)

Use new_zeros in evenly_distribute_backward (#46674)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46674

Summary
-------

This adds batched gradient support (i.e., the ability to vmap through the gradient formulas) for Tensor.max(), Tensor.min(), and Tensor.median(), which have evenly_distribute_backward as their backward formula.

Previously, the plan was to register incompatible gradient formulas as backward operators (see #44052). However, it turns out that we can just use `new_zeros` to get around some incompatible gradient formulas (see the next section for discussion).

Context: the vmap+inplace problem
---------------------------------

A lot of backward functions are incompatible with BatchedTensor because they use in-place operations. Sometimes we can allow the in-place operations, but other times we can't. For example, consider select_backward:
```
Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
  auto grad_input = at::zeros(input_sizes, grad.options());
  grad_input.select(dim, index).copy_(grad);
  return grad_input;
}
```
and consider the following code:
```
x = torch.randn(5, requires_grad=True)

def select_grad(v):
    return torch.autograd.grad(x[0], x, v)

vs = torch.randn(B0)
batched_grads = vmap(select_grad)(vs)
```
For the batched gradient use case, grad is a BatchedTensor. The physical version of grad has size (B0,). However, select_backward creates a grad_input of shape (5,) and tries to copy grad into a slice of it.

Up until now, the proposal for handling this has been to register these backward formulas as operators so that vmap doesn't actually see the `copy_` calls (see #44052). However, it turns out we can actually just use `new_zeros` to construct a new Tensor that has the same "batched-ness" as grad:
```
auto grad_input = grad.new_zeros(input_sizes);
grad_input.select(dim, index).copy_(grad);
```
We should use this for simple backward functions.
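As a minimal sketch of why `new_zeros` helps (a hypothetical illustration, not code from this commit): `Tensor.new_zeros` inherits properties such as dtype and device from the source tensor, and under vmap the same mechanism lets the result pick up the source's batched-ness, whereas a bare factory call like `torch.zeros` only knows what is passed to it explicitly.

```python
import torch

# Hypothetical illustration: new_zeros inherits dtype/device from `grad`,
# the same way it would inherit "batched-ness" if `grad` were a
# BatchedTensor inside vmap. A bare torch.zeros call only knows the
# properties we spell out by hand.
grad = torch.randn(3, dtype=torch.float64)

explicit = torch.zeros(5, dtype=grad.dtype, device=grad.device)  # properties passed manually
inherited = grad.new_zeros(5)                                    # properties inherited from grad

assert inherited.dtype == grad.dtype
assert inherited.device == grad.device
assert explicit.dtype == inherited.dtype
```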
For more complicated backward functions where this solution doesn't work, we should register those as operators.

Alternatives
------------

Option 2: Register `evenly_distribute_backward` as an operator and have the vmap fallback run it in a loop.
- This requires more LOC changes.
- Furthermore, we'd have to write an efficient batching rule for `evenly_distribute_backward` in the future.
- If we use `new_zeros` instead, we don't need to write an efficient batching rule for `evenly_distribute_backward` as long as its constituents have efficient batching rules.

Option 3: Have factory functions behave differently when they are called inside vmap.
- For example, `at::zeros(3, 5)` could return a Tensor of shape `(B0, B1, 3, 5)` if we are vmapping over two dimensions with sizes B0 and B1. This requires maintaining some global and/or thread-local state about the sizes of the dims being vmapped over, which can be tricky.
- And more...

Future
------

- I will undo some of the work I've done in the past to move backward functions to being operators (#44052, #44408). The simpler backward functions (like select_backward) can just use Tensor.new_zeros. I apologize for the thrashing.
- Include a NOTE about the vmap+inplace problem somewhere in the codebase. I don't have a good idea of where to put it at the moment.

Test Plan
---------

- New tests

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D24456781

Pulled By: zou3519

fbshipit-source-id: 9c6c8ee2cb1a4e25afd779bdf0bdf5ab76b9bc20