e40a5630 - Fix sum batching rule, add simple clone batching rule (#47189)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47189

PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail and instead returns a new copy of the original scalar_tensor. If we end up vmapping over per-example scalar tensors, e.g.,
```
>>> x = torch.randn(B0)  # the per-examples are all scalars
>>> vmap(partial(torch.sum, dim=0))(x)
```
then we should replicate the behavior of sum(scalar_tensor, dim=0) by returning a clone of the input tensor.

This PR also adds a batching rule for clone(Tensor, MemoryFormat). The batching rule:
- unwraps the BatchedTensor, calls clone(), and rewraps the result as a BatchedTensor if MemoryFormat is torch.preserve_format (the default).
- errors out with an NYI for all other memory formats, including torch.contiguous_format. There are some weird semantics for memory layouts with vmap that I still need to figure out; those are noted in the comments for `clone_batching_rule`.

Test Plan:
- new tests

Reviewed By: ejguan

Differential Revision: D24741689

Pulled By: zou3519

fbshipit-source-id: e640344b4e4aa8c0d2dbacc5c49901f4c33c6613
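A minimal runnable sketch of the sum fix described above. Note the commit predates the stable `torch.vmap` API (vmap was experimental at the time); the sketch assumes a current PyTorch build where `torch.vmap` is available.

```python
import torch
from functools import partial

# Special case the commit describes: sum over dim=0 of a 0-dim (scalar)
# tensor does not error; it returns a copy of the scalar.
t = torch.randn(())
out = torch.sum(t, dim=0)
assert out.dim() == 0 and out.item() == t.item()

# Under vmap, each per-example slice of x is a scalar, so the fixed
# batching rule mirrors the special case by returning a clone.
B0 = 4
x = torch.randn(B0)
y = torch.vmap(partial(torch.sum, dim=0))(x)
assert torch.equal(y, x)
```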
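And a behavioral sketch of the clone batching rule, under the same `torch.vmap` assumption. The `preserve_format` path is the simple unwrap/clone/rewrap case; the other memory formats hit the NYI error in the version this commit targets, though newer releases may have since added support, so the sketch checks both outcomes.

```python
import torch

x = torch.randn(4, 3)

# preserve_format (the default) takes the simple path the commit adds:
# unwrap the BatchedTensor, call clone(), rewrap.
y = torch.vmap(lambda xi: xi.clone(memory_format=torch.preserve_format))(x)
assert torch.equal(y, x)

# Any other memory format (including torch.contiguous_format) raised a
# not-yet-implemented error in the vmap of this commit.
try:
    torch.vmap(lambda xi: xi.clone(memory_format=torch.contiguous_format))(x)
    print("contiguous_format is supported in this build")
except RuntimeError as err:
    print("not implemented, as in the version this commit targets:", err)
```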