pytorch
389380ff - [reland] Refactor Tensor::to to call a primitive that is not copy_. (#62262)

Commit
3 years ago
[reland] Refactor Tensor::to to call a primitive that is not copy_. (#62262) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62262 Context ------- functorch is unable to vmap(grad(f)) when f contains a .to call. This is because .to (when it is not a no-op) decomposes to .copy_ under grad and the .copy_ is not compatible with vmap. Fix --- The fix for this is to have all Tensor::to variants call a new operator, `_to_copy`, that always copies and is a primitive w.r.t. autograd so that autograd decomposes Tensor::to into a call to `_to_copy`. (This is related to https://github.com/pytorch/pytorch/issues/60956, please let me know if you want to bikeshed the naming). In order to get this done I had to do a bit of refactoring. All of the `::to` implementations now call `to_impl` which may call `_to_copy`. Autograd codegen changes ------------------------ The second thing I had to do was modify the autograd codegen. Right now, autograd assumes that every output is either statically known to be differentiable or not differentiable at codegen time. `_to_copy` is a little special because its differentiability depends on the output dtype. e.g. `torch.randn(3, requires_grad=True).to(torch.long)` is non differentiable. To get this to work: - I changed how `output_differentiability` in derivatives.yaml work. - output_differentiability can now accept "conditions" for each of the output arguments. A "condition" is some C++ code. - We currently only support `output_differentiability` with conditions if there is a single output. This is for convenience and can be changed in the future. - I added a new `output_differentiability_conditions` field to DifferentiabilityInfo. This gets populated in load_derivatives.yaml - forward-mode and reverse-mode AD take `output_differentiability_conditions` into account. Here's how the generated code for `VariableType::_to_copy` [looks like](https://gist.github.com/zou3519/93462df4bda1837acee345205b7cc849) No other autogenerated code gets modified by this PR. Performance benchmarking ------------------------ - I benchmarked [three cases that demonstrate overhead](https://gist.github.com/zou3519/5b6985e6906b80eec5a0dd94ed5b6a1a). - Case A: No-op .to(). Instruction count went from 50223 to 25623. I have no clue why but this is a good thing. - Case B: not-no-op .to(). Instruction count went from 665291 to 671961. This is expected; `_to_copy` adds an additional dispatch. - Case C: not-no-op .to() forward pass and backward pass. Instruction count went from 4022841 to 4030057. This PR adds an additional dispatch to .to() (so there should be one additional dispatch in the forward pass) so this number looks reasonable. Test Plan --------- - test_torch.py has a test_to - test_cuda.py has test_to* - test_autograd has tests (test_type_conversions) that exercise the reverse-mode path - test_ops.py has some tests (like log_softmax) that exercise the reverse-mode and forward-mode AD path. - test_quantization, test_namedtensor all exercise tensor.to as well. Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D29934998 Pulled By: zou3519 fbshipit-source-id: 820069acd66fd5af97b98f42edfca68572c9eb1c
Author
Parents
Loading