Ensure thread id is valid in nested parallel regions (#60183)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60183
Fixes https://github.com/pytorch/pytorch/pull/59149#issuecomment-863287331
`parallel_for` will call the function directly if it would have run on only a
single thread anyway. This is great for performance, but causes an issue in
nested parallel regions because `get_thread_num` will reflect the parent
parallel region instead of the current `parallel_for` call.
I fix this by using a `thread_local` variable for the current thread id and
manually setting it before each call to the user-provided function.
Test Plan: Imported from OSS
Reviewed By: mrshenli
Differential Revision: D29287816
Pulled By: ngimel
fbshipit-source-id: 777f771a0900750c7f22eb1dd185d84d19282108