pytorch
3fa2df7c - Support custom autograd functions in C++ (#23572)

Commit
5 years ago
Support custom autograd functions in C++ (#23572) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23572 ### **(The stack from #23020 was moved into this PR)** Adding API for custom autograd operations, with user defined forward and backward, [like in python](https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd). The custom operation should be a subclass of Function, with static forward and backward functions. `forward()` can accept any arguments similar to the Python API and `backward()` should accept a variable list as an argument. Both `forward()` and `backward() `accept a AutogradContext* which can be used to share data between them. Variables can be saved in the context using `save_for_backward()` and other data can be saved in the map `save` in the form of `<std::string, at::IValue>` pairs. Variables saved in forward can be accessed with `get_saved_variables()`. Example usage: ``` class MyFunction : public Function<MyFunction> { public: static variable_list forward(AutogradContext *ctx, int n, Variable var) { // Save data for backward in context ctx->saved_data["n"] = n; return {var}; } static variable_list backward(AutogradContext *ctx, variable_list grad_output) { // Use data saved in forward auto n = ctx->saved_data["n"].toInt(); return {grad_output[0]*n}; } }; ``` Then, it can be used with: ``` Variable x; MyFunction::apply(6, x); ``` Also AutogradContext has methods to mark outputs as non differentiable and mark inputs as dirty similar to the [Python API](https://github.com/pytorch/pytorch/blob/ff23a02ac4fdf1fe76d5b24666333f1ea0a918b7/torch/autograd/function.py#L26). Test Plan: Added tests for the custom autograd function API based on test_autograd.py. Currently only the tests for the basic functionality have been added. More tests will be added later. Differential Revision: D16583428 fbshipit-source-id: 0bd42f19ce37bcd99d3080d16195ad74d40d0413
Author
mal
Parents
Loading