Implement `_broadcast_to_and_flatten(pytree, spec)` (#46288)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46288
This "broadcasts" `pytree` to have the same structure as `spec`
and then flattens it.
I find it hard to describe what that does in words, so here's an example:
- Broadcasting 1 to have the same structure as [0, [0, 0]] would
return [1, [1, 1]]. Further flattening it gives us [1, 1, 1].
- Broadcasting [1, 2] to have the same structure as [0, [0, 0]] would
return [1, [2, 2]]. Further flattening it gives us [1, 2, 2].
What is this used for?
----------------------
The next PR up in the stack uses this helper function to allow vmap to
accept nested data structures. `vmap(fn, in_dims)(*inputs)` allows the
user to specify in_dims with a tree structure that is a sub-graph of
that of `inputs` (where both contain the root of the tree).
For example, one can do `vmap(fn, in_dims=0)(x, y, z)`. `in_dims` is 0
and inputs is (x, y, z). We would like to broadcast in_dims up to the
structure of inputs to get (0, 0, 0).
Another example, is `vmap(fn, in_dims=(0, 1))(x, [y, z])`. `in_dims` is
(0, 1) and inputs is (x, [y, z]). We would like to broadcast in_dims up
to the structure of inputs to get (0, [1, 1]); this value of in_dims is
used to say "let's vmap over dim 0 for x and dim 1 for y and z".
Test Plan
---------
New tests.
Test Plan: Imported from OSS
Reviewed By: heitorschueroff
Differential Revision: D24392891
Pulled By: zou3519
fbshipit-source-id: 6f494d8b6359582f1b4ab6b8dd6a956d8bfe8ed4