pytorch
370310be - batched grad for binary_cross_entropy, symeig (#48057)

Commit
4 years ago
batched grad for binary_cross_entropy, symeig (#48057) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48057 This PR fixes batched grad computation for: - binary_cross_entropy (i.e., vmap through binary_cross_entropy_double_backward) - symeig (i.e. vmap through symeig_backward) It was previously impossible to vmap through those functions because they use in-place operations in a vmap-incompatible way. See note at https://github.com/pytorch/pytorch/blob/233192be7334ab1dda217472f67d1de5bcb6685b/aten/src/ATen/BatchedFallback.cpp#L117-L122 for what it means for an in-place operation to be vmap-incompatible. This PR adds a check: if the in-place operations in e.g. symeig are vmap-incompatible and we are inside of a vmap, then we do the out-of-place variant of the operation. Ditto for binary_cross_entropy. This is to avoid code duplication: the alternative would be to register the backward formula as an operator and change just those lines to be out-of-place! This PR also adds some general guidelines for what to do if an in-place operation is vmap-incompatible. General guidelines ------------------ If an in-place operation used in a backward formula is vmap-incompatible, then as developers we have the following options: - If the in-place operation directly followed the creation of a tensor with a factory function like at::zeros(...), we should replace the factory with a corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call propagates the batch dims to the resulting tensor. For example: Before: at::zeros(input.sizes(), grad.options()).copy_(grad) After: grad.new_zeros(input.sizes()).copy_(grad) - If the in-place operation followed some sequence of operations, if the we want to be able to vmap over the backward formula as-is (this is usually the case for simple (<15loc) backward formulas), then use inplace_is_vmap_compatible to guard the operation. For example: c = a * b Before: c.mul_(grad) After: c = inplace_is_vmap_compatible(c, grad) ? c.mul_(grad) : c * grad - If we don't want to vmap directly over the backward formula (e.g., if the backward formula is too complicated or has a lot of vmap-incompatible operations, then register the backward formula as an operator and eventually write a batching rule for it. Test Plan --------- New tests Test Plan: Imported from OSS Reviewed By: zhangguanheng66 Differential Revision: D25069525 Pulled By: zou3519 fbshipit-source-id: e0dfeb5a812f35b7579fc6ecf7252bf31ce0d790
Author
Parents
Loading