f2a38a0e - Enabled BFloat16 support for argmax & argmin on both CPU & CUDA (#52582)

Summary:
1. Enabled `BFloat16` support for `argmax` & `argmin` on both CPU & CUDA.
2. Added `OpInfo`s for `argmax` & `argmin`.
3. Enabled `test_argminmax_multiple` for `float16`. It can't be enabled for `bfloat16`, as the comparison is done against numpy, which doesn't currently support `bfloat16`.
4. Enabled `test_dim_arg_reduction_scalar` for `float16` & `bfloat16`.
5. Enabled `test_reduction_vectorize_along_output` for `bfloat16`.
6. Enabled `test_reduction_vectorize_along_input_corner` for `bfloat16`.
7. Enabled `test_dim_reduction` for both `float16` and `bfloat16`, except for `prod`, which neither dtype supports on CPU.
8. Unskipped `TestCommonCPU.test_variant_consistency_jit` for dtype `bfloat16` for `amax` & `amin`, as they're now passing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52582
Reviewed By: anjali411
Differential Revision: D27204704
Pulled By: heitorschueroff
fbshipit-source-id: cdad5df494d070f8e1a8fb83939441a91124b4d9
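
A minimal sketch of what this change enables: calling `torch.argmax` / `torch.argmin` on `bfloat16` tensors on CPU and (if available) CUDA. The tensor shapes and values below are illustrative only, not taken from the PR's tests.

```python
import torch

# bfloat16 input on CPU; argmax/argmin over dim=1 now work for this dtype.
x = torch.randn(4, 8).to(torch.bfloat16)
print(torch.argmax(x, dim=1))
print(torch.argmin(x, dim=1))

# Same calls on CUDA, when a GPU is present.
if torch.cuda.is_available():
    x_cuda = x.cuda()
    print(torch.argmax(x_cuda, dim=1))
    print(torch.argmin(x_cuda, dim=1))
```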
Author: sanchit