pytorch
b3071e2e - functionalization: skip meta reference compute for aot autograd (#87108)

Commit
2 years ago
functionalization: skip meta reference compute for aot autograd (#87108) The context is that historically, XLA/LTC tensors haven't had accurate stride information, and functionalization would run "reference" meta kernels for view ops on the side to properly compute strides. This is more complicated in symint tracing world - we have a `FunctionalTensorWrapper()` that wraps the underlying tensor and has its own set of sizes/strides metadata, but we never create proxy objects for the sizes/strides of the wrapper. In symint tracing world with aot autograd, we're guaranteed that our underlying strides are accurate anyway, since aot autograd uses fake tensors to perform tracing. We encountered a few bugs with symint's from the `FunctionalTensorWrapper` making their way into `__torch_dispatch__`. To side-step that area of bugs completely (and marginally improve perf), this PR disables the meta tensor tracing for non XLA/LTC use cases. Pull Request resolved: https://github.com/pytorch/pytorch/pull/87108 Approved by: https://github.com/ezyang, https://github.com/wconstab
Author
Committer
Parents
Loading