Note: Links to docs will display an error until the docs builds have been completed.
As of commit d0f5719 with merge base 46a35a1:
Looks good so far! There are no failures yet.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
TensorShape change doesn't look sound
Hmm, some of the CI failures look like they need guards suppressed when creating the functional tensor; use shape_env.suppress_guards.
Pure guess here based on your suggestion. There is a call to torch._to_functional_tensor in meta_utils.py. I think you are suggesting I guard that call as follows:
with maybe_suppress():
    r = torch._to_functional_tensor(unwrapped)
??
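For reference, a minimal sketch of what such a helper could look like; the name maybe_suppress and the way shape_env gets here are my assumptions, not code taken from the PR:
from contextlib import nullcontext

def maybe_suppress(shape_env=None):
    # Hypothetical helper: suppress guard creation only when a ShapeEnv is
    # available; otherwise fall back to a no-op context manager.
    if shape_env is not None:
        return shape_env.suppress_guards()
    return nullcontext()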
Hmm. Executing on this suggestion does not appear straightforward to me. Here is where I am at the moment: torch._to_functional_tensor is called from to_functional and from an instance method wrap in FunctionalTensorMode. shape_env is not available in either of these calling contexts. This makes me think that some plumbing work is required to get a previously constructed shape_env into these contexts so that shape_env.suppress_guards can be invoked.
From WC: "I'm not at a computer but my gut idea is to suppress in aot autograd code"
To elaborate further, you want to get the stack trace that produced the guard in question. This can be done with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED, which you set to the string form of the guard you care about.
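For example (the guard expression below is a placeholder; substitute the one reported for the failing guard):
TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s0, 2)" python test/functorch/test_aotdispatch.py -k <failing test>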
Added a couple of suppress_guards calls in aot autograd code.
@ezyang : Thanks to your pointers, I made the changes to use suppress_guards in aot_autograd code. Most of those failures are gone now.
Now I see many failures of variations of this test:
python test/functorch/test_aotdispatch.py -k TestEagerFusionModuleInfoCPU.test_aot_autograd_symbolic_module_exhaustive_nn_CrossEntropyLoss_cpu_float32
These errors are all of this form:
RuntimeError: Cannot call numel() on tensor with symbolic sizes/strides
The error is raised from C++ code here:
int64_t numel_default() const {
if (C10_UNLIKELY(has_symbolic_sizes_strides_)) {
throw_cannot_call_with_symbolic("numel");
}
return numel_;
}
It is strange that the change here, which only supplies an explicit storage size in the call to torch.Tensor._make_wrapper_subclass, could have this effect.
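For context, roughly the kind of call being discussed; this is only a sketch, and in particular the storage_size keyword and the units it expects are my assumptions about the API rather than code from this PR:
class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # Sketch: mirror elem's metadata in a wrapper subclass while also
        # passing an explicit storage size (assumed keyword name).
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            device=elem.device,
            requires_grad=elem.requires_grad,
            storage_size=elem.untyped_storage().size(),
        )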
Any suggestions?
Continuing the debugging of
python test/functorch/test_aotdispatch.py -k TestEagerFusionModuleInfoCPU.test_aot_autograd_symbolic_module_exhaustive_nn_CrossEntropyLoss_cpu_float32
After looking at the C++ stack trace, here is what I have discovered:
The exception is being thrown inside cross_entropy_loss_symint in file LossNLL.cpp. This function calls t.numel() on a tensor t that has symbolic sizes/strides, and that is what raises the error above. Interestingly, there is another method, sym_numel(), which apparently is capable of handling symbolic sizes.
I am wondering whether the problem is that this code should be calling sym_numel() rather than numel().
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased add-checks-tr onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout add-checks-tr && git pull --rebase)
RuntimeError: Cannot call numel() on tensor with symbolic sizes/strides
https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.9ozikwcc6g7b explains the error in more detail, but the short answer is that somewhere there's a numel() call in C++ that should be sym_numel(), so yes, your surmise is correct. Try changing it.
@ezyang : I got past the sym_numel() business. Only two tests seem to be failing now. Consider this one:
export/test_export.py::TestExport::test_disable_forced_specializations_errors
When I read this test, it seems that the expectation is that a UserError exception must be thrown.
class Foo(torch.nn.Module):
    def forward(self, w, x, y, z):
        return w.reshape([-1]) + x, y + z  # simple: s0*s1 = s2, s3 = s4

inputs = (
    torch.randn(3, 4),
    torch.randn(12),
    torch.randn(4),
    torch.randn(4),
)
dynamic_shapes = {
    "w": [Dim(f"dw{i}") for i in range(2)],
    "x": [Dim(f"dx{i}") for i in range(1)],
    "y": [Dim("dy")],  # y & z incorrect, export is supposed to fail.
    "z": [Dim("dz")],  # suggested fix should be to match these up.
}
Please say a few words about why the expected output for this test is a UserError.
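For my own notes, a sketch of what "matching these up" presumably means (my reading of the inline comments, not taken from the test file):
# Give y and z the same Dim so the dy == dz relationship implied by y + z holds.
dz = Dim("dz")
dynamic_shapes_fixed = {
    "w": [Dim("dw0"), Dim("dw1")],
    "x": [Dim("dx0")],
    "y": [dz],
    "z": [dz],
}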
@ezyang : Now consider the second test that is failing. This one seems like a bug in symbolic_shapes.py.
File "/data/users/shaz/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 5407, in _refine_ranges
vr = self.var_to_range[symbol]
KeyError: s0
At the line above, a symbol s0 is being looked up in var_to_range and not being found.
For the same symbol s0, a warning was emitted earlier:
W0610 14:40:28.664000 140396975735808 torch/fx/experimental/symbolic_shapes.py:4449] s0 is not in var_ranges, defaulting to unknown range.
At the place where this warning is emitted, I expected the "unknown range" to get populated into var_to_range, but that does not seem to happen. Could this be the root cause?
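If that is the root cause, one defensive variant of the lookup at symbolic_shapes.py:5407 might be the following; this is only a sketch, assuming ValueRanges.unknown() is the appropriate default here:
from torch.utils._sympy.value_ranges import ValueRanges

# Hypothetical: fall back to an unknown range instead of raising KeyError
# when the symbol was never registered in var_to_range.
vr = self.var_to_range.get(symbol, ValueRanges.unknown())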
@ezyang: I was too optimistic, sigh ... In addition to the two unresolved errors mentioned above, there are two more bugs that my change is tickling.
python test/inductor/test_torchinductor_dynamic_shapes.py -k DynamicShapesCpuTests.test_tensor_index_slice_dynamic_shapes_cpu
tickles
E0610 16:42:03.541000 140211530417152 torch/_guards.py:262] [0/0] AssertionError: s0 (could be from ["L['a']._base.size()[0]"]) not in {s1: ["L['a'].size()[0]"], s2: ["L['a'].size()[1]"], s3: ["L['a'].size()[2]"], s4: ["L['a'].size()[3]"], s5: ["L['a'].size()[4]", "L['a'].stride()[3]"], s0: []}. If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665
Some comments suggest that we need to fix this issue.
Next,
python test/inductor/test_torchinductor_dynamic_shapes.py -k DynamicShapesCpuTests.test_views5_dynamic_shapes_cpu
tickles
File "/data/users/shaz/a/pytorch/torch/_inductor/ir.py", line 2246, in _dynamic_reshape_indexer
var2, size_new2 = stack_new.pop()
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fx_wrapper' raised:
LoweringException: IndexError: pop from empty list
target: aten.reshape.default
args[0]: TensorBox(
ReinterpretView(
StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cpu', torch.float32, size=[s0, s0, s0, s0], stride=[s0**3, s0**2, s0, 1]))
),
FixedLayout('cpu', torch.float32, size=[s0, s0 - 4, s0, s0], stride=[s0**3, s0**2, s0, 1], offset=4*s0**2),
origins={slice_2}
)
)
args[1]: [s0, -1, 4]
I wonder if this bug has also been reported already.
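For what it's worth, here is an untested repro sketch distilled from the IR dump above; the op pattern (a slice on dim 1 followed by the reshape) is my inference from the FixedLayout sizes and offset, not the actual test body:
import torch

@torch.compile(dynamic=True)
def f(a):
    # matches size=[s0, s0 - 4, s0, s0] with offset=4*s0**2, then reshape to [s0, -1, 4]
    return a[:, 4:].reshape(a.size(0), -1, 4)

f(torch.randn(8, 8, 8, 8))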
     else:
         ctx = nullcontext()
-    with ctx:
+    with ctx, maybe_suppress():
I think the scope of the maybe_suppress is too aggressive here. I only want to suppress when creating FunctionalTensors; for everything else I still want normal guards to be reported.
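In other words, something closer to the following; this is a sketch of the narrower scope, reusing names from the discussion above rather than the actual diff:
with ctx:
    # ... existing tracing logic, still emitting guards normally ...
    with maybe_suppress():
        # suppress only around the FunctionalTensor construction itself
        functional_arg = torch._to_functional_tensor(unwrapped)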
+    with maybe_suppress():
+        compiled_fn, fw_metadata = compiler_fn(
+            flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata
+        )
This is especially unlikely to be correct, because Inductor can generate new guards and we really DO need to respect them; they are necessary for correctness.
I'd also emphasize the importance of having a test case in PyTorch itself for this problem; this is possibly more important than hosing down the rest of the failures.
I reverted my change for suppressing guards, expecting to see some failures in the aot_autograd tests pop up again. But now I don't see them. Perhaps a rebase I did a bit earlier pulled in other fixes that took care of these failures. Now the only errors left are the ones described in this comment.
@pytorchbot rebase
@ezyang : I have isolated the sym_numel change in this PR now. Can we land it?
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased add-checks-tr onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout add-checks-tr && git pull --rebase)
@pytorchbot label "topic: not user facing"
@pytorchbot merge
Your change will be merged once all checks pass (ETA 0-4 Hours).
Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team
This PR replaces the call to numel with sym_numel in cross_entropy_loss_prob_target.