[JIT] Ensure offset is a multiple of 4 to fix "Philox" RNG in jitted kernels (#50169)
Summary:
Immediately-upstreamable part of https://github.com/pytorch/pytorch/pull/50148.
This PR fixes what I'm fairly sure is a subtle bug with custom `Philox` class usage in jitted kernels. `Philox` [constructors in kernels](https://github.com/pytorch/pytorch/blob/68a6e4637903dba279c60daae5cff24e191ff9b4/torch/csrc/jit/codegen/cuda/codegen.cpp#L102) take the cuda rng generator's current offset. The Philox constructor then carries out [`offset/4`](https://github.com/pytorch/pytorch/blob/74c055b24065d0202aecdf4bc837d3698d1639e1/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu#L13) (a uint64_t division) to compute its internal offset in its virtual Philox bitstream of 128-bit chunks. In other words, it assumes the incoming offset is a multiple of 4. But (in current code) that's not guaranteed. For example, the increments used by [these eager kernels](https://github.com/pytorch/pytorch/blob/74c055b24065d0202aecdf4bc837d3698d1639e1/aten/src/ATen/native/cuda/Distributions.cu#L171-L216) could easily make offset not divisible by 4.
I figured the easiest fix was to round all incoming increments up to the nearest multiple of 4 in CUDAGeneratorImpl itself.
Another option would be to round the current offset up to the next multiple of 4 at the jit point of use. But that would be a jit-specific offset jump, so jit rng kernels wouldn't have a prayer of being bitwise accurate with eager rng kernels that used non-multiple-of-4 offsets. Restricting the offset to multiples of 4 for everyone at least gives jit rng the chance to match eager rng. (Of course, there are still many other ways the numerics could diverge, like if a jit kernel launches a different number of threads than an eager kernel, or assigns threads to data elements differently.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50169
Reviewed By: mruberry
Differential Revision: D25857934
Pulled By: ngimel
fbshipit-source-id: 43a75e2d0c8565651b0f12a5694c744fd86ece99