Register some backwards functions as operators (#44052)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44052
Summary
=======
This PR registers the following backwards functions as operators:
- slice_backward
- select_backward
- gather_backward
- index_select_backward (the backward function for index_select)
- select_index_backward (prevously known as index_select_backward, but is actually the backward function for max.dim, min.dim, etc)
In the future, I'd like to register more backward functions as operators
so that we can write batching rules for the backward functions. Batching
rules for backward functions makes it so that we can compute batched
gradients.
Motivation
==========
The rationale behind this PR is that a lot of backwards functions (27 in total)
are incompatible with BatchedTensor due to using in-place operations.
Sometimes we can allow the in-place operations, but other times we can't.
For example, consider select_backward:
```
Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
auto grad_input = at::zeros(input_sizes, grad.options());
grad_input.select(dim, index).copy_(grad);
return grad_input;
}
```
and consider the following code:
```
x = torch.randn(5, requires_grad=True)
def select_grad(v):
torch.autograd.grad(x[0], x, v)
vs = torch.randn(B0)
batched_grads = vmap(select_grad)(vs)
```
For the batched gradient use case, `grad` is a BatchedTensor.
The physical version of `grad` has size `(B0,)`.
However, select_backward creates a `grad_input` of shape `(5)`, and
tries to copy `grad` to a slice of it.
Other approaches
================
I've considered the following:
- register select_backward as an operator (this PR)
- have a branch inside select_backward for if `grad` is batched.
- this is OK, but what if we have more tensor extensions that want to override this?
- modify select_backward to work with BatchedTensor, by creating a new operator for the "select + copy_ behavior".
- select + copy_ isn't used elsewhere in derivative formulas so this doesn't seem useful
Test Plan
=========
- `pytest test/test_autograd.py -v`
- Registering backward functions may impact performance. I benchmarked
select_backward to see if registering it as an operator led to any noticable
performance overheads: https://gist.github.com/zou3519/56d6cb53775649047b0e66de6f0007dc.
The TL;DR is that the overhead is pretty minimal.
Test Plan: Imported from OSS
Reviewed By: ezyang, fbhuba
Differential Revision: D23481183
Pulled By: zou3519
fbshipit-source-id: 125af62eb95824626dc83d06bbc513262ee27350