Use `pytree.tree_leaves` everywhere (#112324)
Summary:
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.
X-link: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
Reviewed By: ZainRizvi
Differential Revision: D50819663
fbshipit-source-id: 110cbd1295a752fb8b73fbd71009b5823a2cfa86