Add functorch TLS to ATen/ThreadLocalState (#69181)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69181
functorch lives out-of-tree. However, it has some TLS that needs to be
propagated. The solution for that is we store a pointer to the TLS
inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to
include whatever functorch needs.
A previous solution used ThreadLocalDebugInfo. However, all
PyTorch-managed threads (e.g. spawned by Autograd) all receive a
shared_ptr that points to the same ThreadLocalDebugInfo. This leads to
race conditions if the multiple threads start modifying the TLS
stored within ThreadLocalDebugInfo without using mutexes.
Test Plan:
- tested with functorch
- The performance impact of this change when functorch is not used is
negligible because we end up manipulating nullptrs.
Reviewed By: albanD
Differential Revision: D32742312
Pulled By: zou3519
fbshipit-source-id: 1a8439a4af06b3d3e50b9a2dbca98a0ba612062a