pytorch
03f3a033 - add slice/select/diagonal_scatter variants as primitive ops (#64430)

Commit
4 years ago
add slice/select/diagonal_scatter variants as primitive ops (#64430) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64430 The functionalization pass needs `{view}_scatter` versions of the slice/select/diagonal ops in order to correctly propagate mutations from a view to its base. On top of that, the implementations need to be primitive w.r.t. autograd, because they look something like `...slice().copy_()`, and the functionalization pass can't use views + mutations inside of it's own alias-removal machinery! I added some basic tests that I tried to base off of existing tests for views (particularly around testing the derivative formulas), but I'm wondering if I should add something more comprehensive. Also, as_strided fits into this category - the functionalization pass will need an `as_strided_scatter` op that's primitive w.r.t. autograd. I didn't add it for now, because it'll involve duplicating a bunch of logic from the current `as_strided_backward()` function, and also writing a derivative formula that I wasn't sure how to write :) Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D31942092 Pulled By: bdhirsh fbshipit-source-id: c702a57c2748a7c771c14e4bcc3e996b48fcc4c8
Author
Parents
Loading