pytorch
b1c85dd9 - Custom RNG DispatchKey (#32325)

Commit
4 years ago
Custom RNG DispatchKey (#32325) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32325 The purpose of this PR is to enable PyTorch dispatching on `at::Generator*` parameters and demonstrate how it can be used in cpp extensions to implement custom RNG. 1. `CustomRNGKeyId` value added to DispatchKey enum and `DispatchKeySet key_set_` added to `at::Generator` 2. The overloaded `operator()(at::Generator* gen)` added to MultiDispatchKeySet. 3. The existing CPUGenerator and CUDAGenerator class are supplied with CPUTensorId and CUDATensorId dispatch keys 4. The implementation of CPU's `cauchy_kernel`(as an example, because it's already moved to ATen) was templatized and moved to `ATen/native/cpu/DistributionTemplates.h` to make it available for cpp extensions 5. Minor CMake changes to make native/cpu tensors available for cpp extensions 6. RegisterCustomRNG test that demonstrates how CustomCPUGenerator class can be implemented and how custom_rng_cauchy_ native function can be registered to handle Tensor::cauchy_ calls. Test Plan: Imported from OSS Differential Revision: D19604558 Pulled By: pbelevich fbshipit-source-id: 2619f14076cee5742094a0be832d8530bba72728
Author
Parents
Loading