Apply the same 'pick_grad' when generating fp64 reference outputs (#111593)
Summary:
Apply the same 'pick_grad' when generating the fp64 reference outputs, which lowers memory consumption in inference mode.
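
A minimal sketch of the idea, assuming 'pick_grad' selects torch.no_grad() for inference and torch.enable_grad() for training. The helper bodies and the name fp64_reference are hypothetical stand-ins for illustration, not the actual benchmark code:

```python
import copy

import torch


def pick_grad(is_training: bool):
    # Hypothetical stand-in: enable autograd for training, disable it for
    # inference so no autograd graph is recorded.
    return torch.enable_grad() if is_training else torch.no_grad()


def fp64_reference(model, inputs, is_training: bool):
    # Build an fp64 copy of the model and inputs, leaving the originals untouched.
    model_fp64 = copy.deepcopy(model).double()
    inputs_fp64 = [x.double() for x in inputs]
    # Run the reference forward pass under the same grad mode as the main run.
    with pick_grad(is_training):
        return model_fp64(*inputs_fp64)


model = torch.nn.Linear(4, 2)
inputs = [torch.randn(3, 4)]
ref_inference = fp64_reference(model, inputs, is_training=False)
ref_training = fp64_reference(model, inputs, is_training=True)
```

In the inference case the reference output carries no autograd graph, so intermediate buffers that would otherwise be kept for a backward pass can be freed.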
X-link: https://github.com/pytorch/pytorch/pull/111593
Approved by: https://github.com/msaroufim, https://github.com/thiagocrepaldi
ghstack dependencies: #111867
Reviewed By: izaitsevfb
Differential Revision: D50626494
fbshipit-source-id: c36dc14493c1013ef87c12001c23bb454396eaff