pytorch
f2eed940 - Register PrimTorch refs as decompositions.

Commit
2 years ago
Register PrimTorch refs as decompositions. For the most part, PrimTorch refs have the same signature as their ATen equivalents. I modify most PrimTorch refs to register themselves as decompositions, using the prim name they wrap to find the aten name (except for a few cases where the prim/aten names mismatch). There are some exclusions, falling into one of two categories: - The torch equivalent was already implemented as a CompositeImplicitAutograd decomposition in C++ - The ref doesn't support enough features (e.g., the real deal has more kwargs / overloads than are currently implemented) PrimTorch refs are written as a single function that supports all overloads, and this style is convenient for cases where we have a bundle of overloads for what morally is a single overload with a Union type on an argument (which we ought to have supported in native_functions.yaml but blah); to support registering a single decomp for all the overloads, we modify register_decomposition to register to ALL overloads if you pass it an overload packet. This is technically BC breaking but no tests started failing because of it. Signed-off-by: Edward Z. Yang <ezyangfb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/76835 Approved by: https://github.com/Chillee, https://github.com/mruberry
Author
Committer
Parents
Loading