Initial prims, references, and test architecture for them (#75095)
Summary:
This PR adds an initial set of experimental primitive operations and Python references that reimplement existing PyTorch operations using them. See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577 for additional context.
The following experimental primitives are added:
- Elementwise unary prims -- abs, acos, acosh, asin, atan, cos, cosh, bessel_i0e, bessel_i1e, cbrt, ceil, digamma, erf, erf_inv, erfc, exp, expm1, floor, igamma, igammac, is_finite, lgamma, log, log1p, neg, reciprocal, round, sign, sinh, sqrt, square, tan.
- Elementwise binary prims -- add, atan2, bitwise_and, bitwise_not, bitwise_or, bitwise_xor, div, eq, ge, gt, le, lt, max, min, mul, ne, nextafter, pow, rsqrt, shift_left, shift_right_arithmetic
- View prims -- brodcast_in_dim, collapse_view, split_dim, squeeze
- Shape prims -- collapse, concatenate, reshape
- Conditional prims -- select
- Data conversion & movement prims -- convert_element_type, device_put
- Inplace prims -- copy_to, resize
These primitives do not add any new functionality to PyTorch, but are intended to be the semantic building blocks for reference operators. We have tried to make them consistent with the operations in [jax.lax](https://jax.readthedocs.io/en/latest/jax.lax.html) where possible (because PyTorch prefers being consistent with other frameworks), although there are key differences between these prims and operations in jax.lax. Most notably is that these prims model view semantics and inplace operations.
In addition to these primitives the following elementwise binary Python references are added:
- Elementwise binary Python references -- add, atan2, bitwise_and, bitwise_left_shift, bitwise_or, bitwise_right_shift, bitwise_xor, eq, float_power, ge, gt, le, lt, maximum, minimum, mul, ne, nextafter, pow, sub, true_divide
- Conditional Python references - where
- Data conversion & movement references - copy_to
A Python reference implements the same behavior as its corresponding PyTorch operator (excepting slight numerical differences, bug fixes, and in some cases additional features).
The start of an OpInfo-based test architecture for these references is also included in this PR. A new list, `python_ref_db`, is added to `common_methods_invocations.py`. This list introduces the new `ElementwiseBinaryPythonRefInfo`, which inherits input arguments from the original operators' OpInfo, allows them to be overridden, and then constructs the OpInfo for the Python reference using the (potentially modified) arguments. OpInfo-based tests can opt-into testing references by including this new list in the Sequence passed to the `ops` decorator.
cc ngimel csarofeen kevinstephano Lezcano
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75095
Reviewed By: ngimel
Differential Revision: D35888004
Pulled By: mruberry
fbshipit-source-id: 21e77c4456c2a02113367d4bdae168a3a2f33f25
(cherry picked from commit 1d5bcfa99d4e8cf36f60642803a0bfca50e2ea4e)