pytorch
14ff58d4 - [generate_vmap_rule] Delete unused output_shapes (#92024)

Commit
1 year ago
[generate_vmap_rule] Delete unused output_shapes (#92024) We don't actually need `output_shapes` to implement `generate_vmap_rule=True` support for autograd.Function. - We need this in the vjp (backward) case because autograd automatically reduces grad_inputs to inputs and we need to replicate that behavior. In order to replicate that behavior, we recorded the original input shapes so we know how to reduce the grad_input. - There is no such behavior for forward-mode AD, so we don't need to pass an `output_shapes` to reductify. This PR simplifies the API of `reductify` and `reductify_leaf`. Instead of accepting `input_shape_without_bdim` and `allow_expanded_grad`, we now combine these into a single argument, `reduce_to_input_shape_without_bdim`. - if it is None, then we don't do anything - if it is not-None and a shape, then we will reduce the grad to the provided shape. Test Plan: - updated original unittests - wait for test suite Pull Request resolved: https://github.com/pytorch/pytorch/pull/92024 Approved by: https://github.com/soulitzer
Author
Committer
Parents
Loading