pytorch
5d70d128 - [dynamo] turn torch.backends.cudnn.is_acceptable into a constant (#90323)

Commit
2 years ago
[dynamo] turn torch.backends.cudnn.is_acceptable into a constant (#90323) Tracing `torch.backends.cudnn.is_acceptable(Tensor) -> bool:` fails with: ``` ... File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/functions.py", line 196, in call_function return super(UserFunctionVariable, self).call_function(tx, args, kwargs) File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/functions.py", line 67, in call_function return tx.inline_user_function_return( File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 426, in inline_user_function_return result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs) File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 1698, in inline_call return cls.inline_call_(parent, func, args, kwargs) File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 1752, in inline_call_ tracer.run() File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 485, in run and self.step() File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 455, in step getattr(self, inst.opname)(inst) File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 281, in wrapper return inner_fn(self, inst) File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 912, in CALL_FUNCTION self.call_function(fn, args, {}) File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 389, in call_function self.push(fn.call_function(self, args, kwargs)) File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/torch.py", line 431, in call_function tensor_variable = wrap_fx_proxy( File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/builder.py", line 662, in wrap_fx_proxy return wrap_fx_proxy_cls( File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/builder.py", line 820, in wrap_fx_proxy_cls raise AssertionError( AssertionError: torch.* op returned non-Tensor bool call_function <function is_acceptable at 0x7f00deefb790> ``` So instead, evaluate `is_acceptable()` and convert the result to a constant. The result of `is_acceptable(tensor) -> bool` depends on: * dtype/device of the input tensor (this should already be guarded) * properties of the build & whether cudnn is available * some global state that gets initialized during the first call to `torch.backends.cudnn._init()` (this is NOT guarded in this PR) Note: this fixes tts_angular with FSDP. This was an issue with FSDP because FSDP modules are interpreted as UnspecializedNNModules, and UnspecializedNNModules try to inline calls. In comparison, NNModules (e.g. when the tts_angular model is not wrapped in FSDP) do not inline calls and instead evaluate subsequent calls. In subsequent calls, cudnn.is_acceptable would be skipped by eval_frame.py:catch_errors because it is not in an allowlist. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90323 Approved by: https://github.com/jansel
Author
Committer
Parents
Loading