Functionalization of torch.rand/rand_like ops (#97377)
This PR introduces functionalization of RNG ops. Key points:
* Introduces a new `philox_rand` prim operator that accepts a seed and an offset.
* Adds decompositions that rewrite random operators in terms of the `philox_rand` prim.
* Adds a `PhiloxStateTracker` to track the offset for each occurrence of a rand op (see the sketch after this list).
* Changes the calling convention of AOT Autograd, adding `<fwd_seed, fwd_base_offset>` and `<bwd_seed, bwd_base_offset>` as graph inputs.
* Monkeypatches `set_rng_state` and `get_rng_state` during AOT Autograd tracing to record RNG state behavior.
* Raises an assertion for CPU because CPU does not use Philox RNG.
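For intuition, here is a minimal, self-contained sketch of the offset bookkeeping the tracker performs: each rand op is pinned to `base_offset` plus the offset consumed by earlier ops, and the running total is advanced after every op. The class and method names are hypothetical, and the per-op increment of 4 is taken from the traced example below (the real increment depends on the op's element count and the CUDA launch configuration):
~~~
class PhiloxStateTrackerSketch:
    """Hypothetical stand-in for the real PhiloxStateTracker."""

    def __init__(self, seed: int, base_offset: int):
        self.seed = seed
        self.base_offset = base_offset
        self.consumed = 0  # offset units used by earlier rand ops

    def next_offset(self) -> int:
        # Offset for the next rand op: base plus everything consumed so far.
        offset = self.base_offset + self.consumed
        # Stand-in advance rule: the traced example below consumes 4 offset
        # units per 16x16 rand op; the real rule is shape/launch dependent.
        self.consumed += 4
        return offset

tracker = PhiloxStateTrackerSketch(seed=123, base_offset=0)
assert tracker.next_offset() == 0  # first torch.rand_like -> base + 0
assert tracker.next_offset() == 4  # second torch.rand_like -> base + 4
~~~
Because every rand op reads from a fixed `(seed, offset)` pair instead of mutating global RNG state, the forward and backward graphs become pure functions, which is what makes them safe to functionalize and replay.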
Not dealt with in this PR:
* dropout op - its offset calculation is different
* other distributions like normal, poisson, etc.
* Inductor support
* Cudagraph support
* Dynamic shape support
An example:
~~~
import torch

class Custom(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        a = torch.rand_like(x) * x
        a = torch.rand_like(x) * a
        return a

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return grad_out * torch.rand_like(grad_out) * torch.cos(x)
====== Forward graph 0 ======
def forward(self, fwd_seed_1: i64[], fwd_base_offset_1: i64[], primals_1: f32[16, 16]):
    # No stacktrace found for following nodes
    add: i64[] = torch.ops.aten.add.Tensor(fwd_base_offset_1, 0)
    philox_rand: f32[16, 16] = torch.ops.prims.philox_rand.default([16, 16], fwd_seed_1, add, [16, 1], device(type='cuda', index=0), torch.float32); add = None
    mul: f32[16, 16] = torch.ops.aten.mul.Tensor(philox_rand, primals_1); philox_rand = None
    add_1: i64[] = torch.ops.aten.add.Tensor(fwd_base_offset_1, 4); fwd_base_offset_1 = None
    philox_rand_1: f32[16, 16] = torch.ops.prims.philox_rand.default([16, 16], fwd_seed_1, add_1, [16, 1], device(type='cuda', index=0), torch.float32); fwd_seed_1 = add_1 = None
    mul_1: f32[16, 16] = torch.ops.aten.mul.Tensor(philox_rand_1, mul); philox_rand_1 = mul = None
    return [mul_1, primals_1]
====== Backward graph 0 ======
def forward(self, bwd_seed_1: i64[], bwd_base_offset_1: i64[], primals_1: f32[16, 16], tangents_1: f32[16, 16]):
    # No stacktrace found for following nodes
    add_2: i64[] = torch.ops.aten.add.Tensor(bwd_base_offset_1, 0); bwd_base_offset_1 = None
    philox_rand_2: f32[16, 16] = torch.ops.prims.philox_rand.default([16, 16], bwd_seed_1, add_2, [16, 1], device(type='cuda', index=0), torch.float32); bwd_seed_1 = add_2 = None
    mul_2: f32[16, 16] = torch.ops.aten.mul.Tensor(tangents_1, philox_rand_2); tangents_1 = philox_rand_2 = None
    cos: f32[16, 16] = torch.ops.aten.cos.default(primals_1); primals_1 = None
    mul_3: f32[16, 16] = torch.ops.aten.mul.Tensor(mul_2, cos); mul_2 = cos = None
    return [mul_3]
~~~
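The graphs above come from AOT Autograd tracing with RNG functionalization enabled. A rough sketch of how one might reproduce them; the `functionalize_rng_ops` config flag name is an assumption about the surrounding functorch config, and `aot_eager` is the stock debugging backend:
~~~
import torch
import torch._functorch.config as functorch_config

# Assumed flag for enabling RNG functionalization during AOT Autograd
# tracing; treat the exact spelling as an assumption.
functorch_config.functionalize_rng_ops = True

def fn(x):
    return Custom.apply(x)

x = torch.randn(16, 16, device="cuda", requires_grad=True)
compiled_fn = torch.compile(fn, backend="aot_eager")
out = compiled_fn(x)
out.sum().backward()
~~~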
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97377
Approved by: https://github.com/ezyang