batched grad for binary_cross_entropy, symeig (#48057)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48057
This PR fixes batched grad computation for:
- binary_cross_entropy (i.e., vmap through binary_cross_entropy_double_backward)
- symeig (i.e. vmap through symeig_backward)
It was previously impossible to vmap through those functions because
they use in-place operations in a vmap-incompatible way.
See note at
https://github.com/pytorch/pytorch/blob/233192be7334ab1dda217472f67d1de5bcb6685b/aten/src/ATen/BatchedFallback.cpp#L117-L122
for what it means for an in-place operation to be vmap-incompatible.
This PR adds a check: if the in-place operations in e.g. symeig are
vmap-incompatible and we are inside of a vmap, then we do the
out-of-place variant of the operation. Ditto for binary_cross_entropy.
This is to avoid code duplication: the alternative would be to register
the backward formula as an operator and change just those lines to be
out-of-place!
This PR also adds some general guidelines for what to do if an in-place
operation is vmap-incompatible.
General guidelines
------------------
If an in-place operation used in a backward formula is vmap-incompatible,
then as developers we have the following options:
- If the in-place operation directly followed the creation of a tensor with
a factory function like at::zeros(...), we should replace the factory with a
corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call
propagates the batch dims to the resulting tensor.
For example:
Before: at::zeros(input.sizes(), grad.options()).copy_(grad)
After: grad.new_zeros(input.sizes()).copy_(grad)
- If the in-place operation followed some sequence of operations, if the
we want to be able to vmap over the backward formula as-is (this is
usually the case for simple (<15loc) backward formulas), then use
inplace_is_vmap_compatible to guard the operation. For example:
c = a * b
Before: c.mul_(grad)
After: c = inplace_is_vmap_compatible(c, grad) ? c.mul_(grad) : c * grad
- If we don't want to vmap directly over the backward formula (e.g., if the
backward formula is too complicated or has a lot of vmap-incompatible
operations, then register the backward formula as an operator and eventually
write a batching rule for it.
Test Plan
---------
New tests
Test Plan: Imported from OSS
Reviewed By: zhangguanheng66
Differential Revision: D25069525
Pulled By: zou3519
fbshipit-source-id: e0dfeb5a812f35b7579fc6ecf7252bf31ce0d790