move dynamo MetricsContext into TLS (#170605)
Summary:
Fixes the following tests in test_functional_differentials.py:
- test_all_reduce_compile
- test_all_gather_tensor_compile
- test_reduce_scatter_tensor_compile
- test_all_to_all_single_compile
The test class is MultiThreadedTestCase, which uses the "threaded" process group. This exposes that dynamo's MetricsContext is not thread-safe when all ranks attempt to use torch.compile() concurrently.
```
Previous Traceback:
File "/opt/conda/envs/py_3.10/lib/python3.10/threading.py", line 973, in _bootstrap
self._bootstrap_inner()
File "/opt/conda/envs/py_3.10/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/opt/conda/envs/py_3.10/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 1410, in _run
self.run_test_with_threaded_pg(test_name, rank, world_size)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 1425, in run_test_with_threaded_pg
getattr(self, test_name)()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 1324, in wrapper
fn()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 3329, in wrapper
method(*args, **kwargs)
File "/var/lib/jenkins/pytorch/test/distributed/test_functional_differentials.py", line 607, in test_all_gather_tensor_compile
loss.backward()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_tensor.py", line 631, in backward
torch.autograd.backward(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/__init__.py", line 365, in backward
_engine_run_backward(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 317, in apply
return user_fn(self, *args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2338, in backward
return impl_fn()
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2324, in impl_fn
out = CompiledFunction._backward_impl(ctx, all_args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2463, in _backward_impl
CompileEventLogger.compilation_metric(is_forward=False)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 617, in compilation_metric
CompileEventLogger.add_toplevel(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 485, in add_toplevel
CompileEventLogger.add_data(top_event, log_level, overwrite, **metadata)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 470, in add_data
metrics_context.update(metadata, overwrite)
RuntimeError: Metric(s) {'is_forward'} have already been set in the current context. (see above for current and previous traceback).
```
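As the title says, the fix is to move dynamo's ambient MetricsContext into thread-local storage so each rank (thread) records its compilation metrics in its own context. Below is a minimal illustrative sketch of that pattern only, not the actual patch; the class body, `_tls`, and `get_metrics_context` names are assumptions for illustration.
```
import threading

class MetricsContext:
    """Stand-in that rejects re-setting a metric, mirroring the
    RuntimeError in the traceback above."""

    def __init__(self):
        self._metrics = {}

    def update(self, metadata, overwrite=False):
        existing = set(metadata) & set(self._metrics)
        if existing and not overwrite:
            raise RuntimeError(
                f"Metric(s) {existing} have already been set in the current context."
            )
        self._metrics.update(metadata)

# Before: a single module-level context shared by every thread.
# _METRICS_CONTEXT = MetricsContext()

# After: one context per thread, so concurrent ranks compiling at the
# same time cannot clobber each other's in-flight metrics (e.g. 'is_forward').
_tls = threading.local()  # hypothetical name

def get_metrics_context() -> MetricsContext:  # hypothetical accessor
    if not hasattr(_tls, "metrics_context"):
        _tls.metrics_context = MetricsContext()
    return _tls.metrics_context
```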
X-link: https://github.com/pytorch/pytorch/pull/170605
Approved by: https://github.com/jansel, https://github.com/mlazos
Reviewed By: seemethere
Differential Revision: D89508347
fbshipit-source-id: 9a7344573d8922c79b0141b24eba9e3becb4fbe4
Author: generatedunixname499836121