[philox_rand] Dynamic shape support (#99290)
Extends the RNG functionalization work to dynamic shapes. An example of the generated graph looks like this:
~~~
[2023-04-24 21:41:37,446] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
===== Forward graph 1 =====
<eval_with_key>.7 class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: i64[], arg1_1: i64[], arg2_1: Sym(s0), arg3_1: Sym(s1), arg4_1: f32[s0, s1]):
        # File: /scratch/anijain/work/pytorch/test/test_functionalization_of_rng_ops.py:46, code: a = torch.rand_like(x) * x
        add: i64[] = torch.ops.aten.add.Tensor(arg1_1, 0)
        philox_rand = torch.ops.rngprims.philox_rand.default([arg2_1, arg3_1], arg0_1, add, None, device(type='cuda', index=0), torch.float32); add = None
        getitem: f32[s0, s1] = philox_rand[0]
        getitem_1: i64[] = philox_rand[1]; philox_rand = None
        add_1: i64[] = torch.ops.aten.add.Tensor(getitem_1, 0); getitem_1 = None
        mul: f32[s0, s1] = torch.ops.aten.mul.Tensor(getitem, arg4_1); getitem = arg4_1 = None
        # File: /scratch/anijain/work/pytorch/test/test_functionalization_of_rng_ops.py:47, code: a = torch.rand_like(x) * a
        add_2: i64[] = torch.ops.aten.add.Tensor(arg1_1, add_1)
        philox_rand_1 = torch.ops.rngprims.philox_rand.default([arg2_1, arg3_1], arg0_1, add_2, None, device(type='cuda', index=0), torch.float32); arg2_1 = arg3_1 = arg0_1 = add_2 = None
        getitem_2: f32[s0, s1] = philox_rand_1[0]
        getitem_3: i64[] = philox_rand_1[1]; philox_rand_1 = None
        add_3: i64[] = torch.ops.aten.add.Tensor(add_1, getitem_3); add_1 = getitem_3 = None
        mul_1: f32[s0, s1] = torch.ops.aten.mul.Tensor(getitem_2, mul); getitem_2 = mul = None
        # No stacktrace found for following nodes
        add_4: i64[] = torch.ops.aten.add.Tensor(arg1_1, add_3); arg1_1 = add_3 = None
        return (mul_1, add_4)
~~~
Each rand op is accompanied by its offset-calculation op: the base offset (`arg1_1`) is combined with a running relative offset (`add`, `add_2`) so that each `philox_rand` draws from a disjoint segment of the Philox stream, and the accumulated total (`add_4`) is returned so the generator state can be advanced past everything the graph consumed.
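For reference, a minimal sketch of the kind of function that yields this trace (the two `rand_like` lines match the stack traces embedded in the graph above; the config flag and shapes here are assumptions for illustration):
~~~
import torch
import torch._functorch.config as functorch_config

# Assumed setup: functionalized RNG tracing is gated by this config flag,
# and the feature targets CUDA.
functorch_config.functionalize_rng_ops = True

def fn(x):
    a = torch.rand_like(x) * x  # first philox_rand, offset `add` in the graph
    a = torch.rand_like(x) * a  # second philox_rand, offset shifted by `add_2`
    return a

# dynamic=True makes the input sizes symbolic (s0, s1 in the graph).
compiled = torch.compile(fn, dynamic=True)
if torch.cuda.is_available():
    out = compiled(torch.randn(8, 16, device="cuda"))
~~~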
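And a hedged sketch of the offset bookkeeping itself, with the prim's signature inferred from the trace ((size, seed, offset, stride, device, dtype) -> (samples, offsets consumed)); the variable names are mine, not part of the op:
~~~
import torch

if torch.cuda.is_available():
    seed = torch.tensor(0, dtype=torch.int64, device="cuda")
    offset = torch.tensor(0, dtype=torch.int64, device="cuda")
    # Illustrative direct call, mirroring the first philox_rand in the graph.
    out, consumed = torch.ops.rngprims.philox_rand.default(
        [8, 16], seed, offset, None, torch.device("cuda"), torch.float32
    )
    # The next rand op runs at offset + consumed, which is exactly the
    # add_2 = arg1_1 + add_1 pattern in the traced graph.
~~~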
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99290
Approved by: https://github.com/ezyang, https://github.com/bdhirsh