benchmark
a31c3fe9 - Revert "Trace enter/exit of TorchFunctionModes (#135422)" (#136590)

Commit
1 year ago
Revert "Trace enter/exit of TorchFunctionModes (#135422)" (#136590) Summary: This reverts commit 7743149b2be4a9eba7e0997ccdc6abe552bec266. Reverts * https://github.com/pytorch/pytorch/pull/135503 * https://github.com/pytorch/pytorch/pull/135502 * https://github.com/pytorch/pytorch/pull/135422 This passes this test. Earlier, the getitem would stay like a getitem in the Fx graph. But now the fake tensor propagations fails saying that .item is called. It seems that torch function is not getting triggered while fake tensor propagation. ``` import torch from torch.nn.attention.flex_attention import BlockMask, _mask_mod_signature, _score_mod_signature, flex_attention from torch._inductor.lowering import make_pointwise, register_lowering from torch._inductor.virtualized import ops from torch.nn.attention.flex_attention import create_block_mask torch.set_default_device('cuda') flex_attention = torch.compile(flex_attention, dynamic=False) prefix_lengths = torch.arange(8) def prefix_lm(b, h, q, kv): return prefix_lengths[b] >= kv mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True) ``` X-link: https://github.com/pytorch/pytorch/pull/136590 Approved by: https://github.com/Chillee Reviewed By: atalman Differential Revision: D63431470 Pulled By: anijain2305 fbshipit-source-id: 60915b30336121b845af71f423582c22a6c65c3f
Author
Parents
Loading