pytorch
2c43876f - AT_DISPATCH: Expose switch-case like macro syntax (#79978)

Commit
2 years ago
AT_DISPATCH: Expose switch-case like macro syntax (#79978) This expands the `AT_DISPATCH` macros to enable writing your own `AT_DISPATCH_SWITCH` statements with multiple `AT_DISPATCH_CASE` labels. So, where previously you may have written: ```cpp if (iter.common_dtype() == kBool) { my_bool_kernel(iter); } else { AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "my_kernel", [&] { ... }); } ``` You can now instead write ```cpp AT_DISPATCH_SWITCH(iter.common_dtype(), "my_kernel", AT_DISPATCH_CASE(kBool, [&] { my_kernel_bool(iter); }) AT_DISPATCH_CASE_INTEGRAL_TYPES([&] { ... }) ); ``` The macro syntax is a bit ugly, however the benefits are: - Greater flexibility, as the kernel code doesn't have to be shared for all dtypes. - Selective build and RECORD_KERNEL_FUNCTION work even for single dtype specializations such as the bool case in the example. - The compiler sees a single switch for all types, which should be easier to optimize into a jump table. - We also now get errors if the same scalar type is handled twice. Pull Request resolved: https://github.com/pytorch/pytorch/pull/79978 Approved by: https://github.com/ezyang
Author
Committer
Parents
Loading