pytorch
31750930 - [primTorch] Adds random operations (#78026)

Commit
2 years ago
[primTorch] Adds random operations (#78026) This PR... **Issues Found** - https://github.com/pytorch/pytorch/issues/78058 - https://github.com/pytorch/pytorch/issues/78054 - https://github.com/pytorch/pytorch/issues/78053 - https://github.com/pytorch/pytorch/issues/78050 - https://github.com/pytorch/pytorch/issues/77932 **Testing** - disables stride consistency checks in test_ops and test_meta pending resolution of https://github.com/pytorch/pytorch/issues/78050 - skips chalf in reference tests (addressing https://github.com/pytorch/pytorch/issues/78054) - splits test test_python_reference_consistency in one test for the ctx where torch.foo is torch.foo, and another for when torch.foo is refs.foo - updates test names to be more natural and consistent: - test_python_reference_errors -> test_python_ref_errors - test_python_reference_consistency -> test_python_ref and test_python_ref_torch_fallback - test_python_reference_meta_functions -> test_python_ref_meta - test_reference_testing -> test_numpy_ref - updates test_python_ref and test_python_ref_torch_fallback to check that the reference is more accurate than the torch op if the reference and torch op results are not close, a warning is raised when this occurs (addressing https://github.com/pytorch/pytorch/issues/77687) - adds reference inputs for broadcast_tensors - Updates the "fill_" OpInfo to "fill", adding a NumPy reference and making it an elementwise unary operator - Adds 1D no element sample inputs to the cat OpInfo and updates the NumPy reference to handle them and type promotion correctly - Adds reference inputs for elementwise ternary operations, like clamp - Adds a NumPy reference for clamp - Adds reference inputs to where's OpInfo - Makes softplus an elementwise unary OpInfo - Removes the great majority of Python reference OpInfo skips and xfails due to the above test changes - Adds Python reference OpInfos for fill, dropout, clamp, broadcast_tensors, and where **Prims** - adds the fill, empty_strided, and uniform prims - removes the empty, empty_like, full, and full_like prims -- these are now references that use empty_strided and fill - renames the "concatenate" and "select" prims to "cat" and "where", respectively, to be consistent with PyTorch - extends the `_elementwise_meta` operation to accepts tensors that don't participate in type promotion, like the `cond` tensor in `where` - fixes a bug in the stride propagation of broadcast_in_dim - moves some error checks from prims.cat to prims.where to refs.cat and refs.where, respectively, consistent with our new policy of doing as much error checking in the ref as possible **Utils** - adds the canoicalize_device, extract_shape, and extract_shape_from_varargs helpers - adds the elementwise_unary_scalar_wrapper -- this allows elementwise unary operators to take and return scalar values (ex. refs.sin(1) will return .84...) **Refs** - adds the fill, broadcast_tensors, clamp, empty_strided, ones, zeros, and uniform references - adds the nn.functional.dropout reference - fixes refs.cat to handle 1D tensors with no inputs consistent with eager mode Pull Request resolved: https://github.com/pytorch/pytorch/pull/78026 Approved by: https://github.com/ngimel
Author
Mike Ruberry
Committer
Parents
Loading