[functorch] dont compute expected output multiple times (#86202)
Fixes https://github.com/pytorch/functorch/issues/1028
Description: We update `get_fallback_and_vmap_exhaustive` to compute expected output only once as described in the issue.
NOTE: This doesn't take care of the repeated computation in `test_vmap_exhaustive` and will be followed up later.
TODO:
* [x] Benchmark and see how much difference does this make. (Comparison Table Below: [Link](https://github.com/pytorch/pytorch/pull/86202#issuecomment-1285477653))
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86202
Approved by: https://github.com/zou3519