Make wrapPropagateTLSState more generic (#57634)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57634
`wrapPropagateTLSState` restricted its argument to a function that takes no arguments, and I need to relax this for later work.
It also required its argument to be converted to `std::function`, and it returned a `std::function` as well. Each creation of a `std::function` can cause a heap allocation. That cost is not particularly high, but it is easy to avoid here by having `wrapPropagateTLSState` operate directly on generic callables (and thus, possibly, on raw lambdas).
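For illustration only, here is a minimal sketch of what a wrapper over generic callables could look like. It is not the code from this PR: `FakeTLSState`, `FakeTLSStateGuard`, `current_state`, `wrapPropagateState` and `demoPropagation` are hypothetical stand-ins invented for this example, and the real thread-local state machinery is assumed to differ.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical stand-in for the thread-local state; not the real PyTorch type.
struct FakeTLSState {
  std::string tag;
};

// One instance per thread, default-initialized to "default".
thread_local FakeTLSState current_state{"default"};

// Hypothetical RAII guard: installs a state, restores the previous one on exit.
class FakeTLSStateGuard {
 public:
  explicit FakeTLSStateGuard(const FakeTLSState& s) : prev_(current_state) {
    current_state = s;
  }
  ~FakeTLSStateGuard() { current_state = prev_; }
  FakeTLSStateGuard(const FakeTLSStateGuard&) = delete;
  FakeTLSStateGuard& operator=(const FakeTLSStateGuard&) = delete;

 private:
  FakeTLSState prev_;
};

// Generic wrapper: accepts any callable (no std::function conversion) and
// returns a lambda that forwards arbitrary arguments to it (C++14 or later).
template <typename F>
auto wrapPropagateState(F&& callback) {
  return [state = current_state,
          callback = std::forward<F>(callback)](auto&&... args) mutable
             -> decltype(auto) {
    FakeTLSStateGuard guard(state);  // state captured at wrap time
    return callback(std::forward<decltype(args)>(args)...);
  };
}

// Wraps a lambda that takes an argument, changes the ambient state, then
// invokes the wrapper and returns what the callback observed.
std::string demoPropagation() {
  current_state.tag = "request-42";
  auto wrapped = wrapPropagateState([](int x) {
    return current_state.tag + ":" + std::to_string(x);
  });
  current_state.tag = "default";  // simulate running in a different context
  std::string seen = wrapped(7);
  assert(current_state.tag == "default");  // guard restored the prior state
  return seen;
}
```

Because the wrapper is a template that returns a closure type, a raw lambda passed to it is stored by value inside the returned lambda, and no `std::function` is constructed along the way. Callers that still need type erasure can convert the result to `std::function` themselves.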
ghstack-source-id: 128295264
Test Plan: CI
Reviewed By: ilia-cher
Differential Revision: D28178782
fbshipit-source-id: d657f5751514974518606dd4fc4175e805dcb90a