pytorch
8619d263 - Add batching rule for torch.expand (#40097)

Commit
4 years ago
Add batching rule for torch.expand (#40097) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40097 This is (probably) necessary for the vmap frontend API (coming up after this PR should be the vmap frontend API). There is some manual handling of sizes in the `expand_batching_rule`. In particular, when performing expand(Tensor[B0, 3], [2, 3]), where B0 is a batch dimension and Tensor[B0, 3] is a batched tensor with batch dimension B0, we can't call expand directly on the physical view and instead first need to perform a view. It's possible to add said view as a helper function on `VmapPhysicalView` but after reading through the operator spreadsheet the conclusion was that no other operator needs the same manual handling. Test Plan: - `./build/bin/vmap_test` Differential Revision: D22070657 Pulled By: zou3519 fbshipit-source-id: 911854b078a1a5c7d5934ef2e17b16673ed9d103
Author
Parents
Loading