pytorch
2a559841 - [generate_vmap_rule] reductify_leaf helper function (#90965)

Commit
2 years ago
[generate_vmap_rule] reductify_leaf helper function (#90965) As seen in https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit `reductify_leaf(grad_input, ...)` is a helper function that processes a single grad_input Tensor. The reason why we need it is: - the grad_input has some optional bdim - the input has some optional bdim - if these are different, we need to coerce the grad_input into having the same shape as the input, either by reducing or expanding the grad_input. Note that there is a special case in autograd that the user is allowed to return a grad_input Tensor that is an expanded version of the original input tensor. In this case, autograd automatically reduces grad_input to the same shape as the input. Unfortunately this logic doesn't work when bdims are involved, so we manually handle it in `reductify_leaf`. Test Plan: - tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/90965 Approved by: https://github.com/soulitzer
Author
Committer
Parents
Loading