Fixed hash issue in `fx_graph_cse` (#119567)
Description:
- Fixed a hash collision in the CSE node map: in Python `1 == 1.0` and `hash(1) == hash(1.0)`, so `hash((primals_2, 1.0)) == hash((primals_2, 1))` and the two key tuples compare equal. As a result, semantically different nodes (one subtracting the int `1`, the other the float `1.0`) were incorrectly merged.
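For context, the collision is a property of plain Python tuples, not of FX itself. The snippet below demonstrates it, and shows one illustrative way to break the tie by tagging constants with their type (a sketch of the general idea, not the exact code in this PR):

```python
# In Python, an int and an equal-valued float hash and compare equal:
assert 1 == 1.0
assert hash(1) == hash(1.0)

# Consequently, arg tuples used as CSE map keys collide:
key_int = ("primals_2", 1)
key_float = ("primals_2", 1.0)
assert key_int == key_float
assert hash(key_int) == hash(key_float)

# Tagging each constant with its concrete type (illustrative fix)
# makes the keys distinct, since int is not float:
tagged_int = ("primals_2", (type(1), 1))
tagged_float = ("primals_2", (type(1.0), 1.0))
assert tagged_int != tagged_float
```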
Repro code:
```python
import torch
from torch._functorch.compile_utils import fx_graph_cse
def func(inpt, osize):
    size = inpt.shape[-1]
    s1 = size - 1
    s2 = size - 1.0
    scale = s2 / (osize - 1.0)
    inpt = torch.clamp(inpt, 0, s1)
    return scale * inpt

gms = []

def toy_backend(gm, _):
    gms.append(gm)
    return gm.forward
torch._dynamo.reset()
fn = torch.compile(backend=toy_backend, dynamic=True)(func)
t = torch.rand(3, 100)
out = fn(t, 50)
gm = gms[0]
print(gm.graph)
new_fx_g = fx_graph_cse(gm.graph)
print(str(new_fx_g))
```
Original graph
```
graph():
    %s0 : torch.SymInt [num_users=0] = placeholder[target=s0]
    %s1 : torch.SymInt [num_users=0] = placeholder[target=s1]
    %l_inpt_ : torch.Tensor [num_users=2] = placeholder[target=L_inpt_]
    %l_osize_ : torch.SymInt [num_users=1] = placeholder[target=L_osize_]
    %size : [num_users=1] = call_method[target=size](args = (%l_inpt_,), kwargs = {})
    %getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%size, 1), kwargs = {})
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%getitem_1, 1), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=operator.sub](args = (%getitem_1, 1.0), kwargs = {})
    %sub_2 : [num_users=1] = call_function[target=operator.sub](args = (%l_osize_, 1.0), kwargs = {})
    %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%sub_1, %sub_2), kwargs = {})
    %inpt : [num_users=1] = call_function[target=torch.clamp](args = (%l_inpt_, 0, %sub), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%truediv, %inpt), kwargs = {})
    return (mul,)
```
New wrong graph, where `sub_1` (`getitem_1 - 1.0`) is incorrectly replaced with `sub` (`getitem_1 - 1`):
```
graph():
    %s0 : torch.SymInt [num_users=0] = placeholder[target=s0]
    %s1 : torch.SymInt [num_users=0] = placeholder[target=s1]
    %l_inpt_ : torch.Tensor [num_users=2] = placeholder[target=L_inpt_]
    %l_osize_ : torch.SymInt [num_users=1] = placeholder[target=L_osize_]
    %size : [num_users=1] = call_method[target=size](args = (%l_inpt_,), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%size, 1), kwargs = {})
    %sub : [num_users=2] = call_function[target=operator.sub](args = (%getitem_1, 1), kwargs = {})
    %sub_2 : [num_users=1] = call_function[target=operator.sub](args = (%l_osize_, 1.0), kwargs = {})
    %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%sub, %sub_2), kwargs = {})
    %inpt : [num_users=1] = call_function[target=torch.clamp](args = (%l_inpt_, 0, %sub), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%truediv, %inpt), kwargs = {})
    return (mul,)
```
With this PR, `sub` and `sub_1` stay distinct and the new graph matches the original:
```
graph():
    %s0 : torch.SymInt [num_users=0] = placeholder[target=s0]
    %s1 : torch.SymInt [num_users=0] = placeholder[target=s1]
    %l_inpt_ : torch.Tensor [num_users=2] = placeholder[target=L_inpt_]
    %l_osize_ : torch.SymInt [num_users=1] = placeholder[target=L_osize_]
    %size : [num_users=1] = call_method[target=size](args = (%l_inpt_,), kwargs = {})
    %getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%size, 1), kwargs = {})
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%getitem_1, 1), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=operator.sub](args = (%getitem_1, 1.0), kwargs = {})
    %sub_2 : [num_users=1] = call_function[target=operator.sub](args = (%l_osize_, 1.0), kwargs = {})
    %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%sub_1, %sub_2), kwargs = {})
    %inpt : [num_users=1] = call_function[target=torch.clamp](args = (%l_inpt_, 0, %sub), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%truediv, %inpt), kwargs = {})
    return (mul,)
```
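The core mechanics can be sketched in isolation: a CSE pass deduplicates nodes through a hash map keyed on `(target, args)`, so the key must distinguish `1` from `1.0`. Below is a minimal, self-contained model of such a type-aware key (an illustration of the idea; `make_key` and `cse_lookup` are hypothetical helpers, not the actual `fx_graph_cse` internals):

```python
# Minimal model of type-aware CSE keying (illustrative, not PyTorch code).

def make_key(target, args):
    # Tag every numeric constant with its concrete type so that
    # int 1 and float 1.0 produce distinct, non-colliding keys.
    tagged = tuple(
        (type(a), a) if isinstance(a, (int, float, bool)) else a
        for a in args
    )
    return (target, tagged)

seen = {}  # key -> id of the canonical node

def cse_lookup(target, args):
    """Return the id of an existing equivalent node, or register a new one."""
    key = make_key(target, args)
    if key not in seen:
        seen[key] = len(seen)
    return seen[key]

# operator.sub(getitem_1, 1) and operator.sub(getitem_1, 1.0) no longer merge:
a = cse_lookup("operator.sub", ("getitem_1", 1))
b = cse_lookup("operator.sub", ("getitem_1", 1.0))
assert a != b
# Truly identical calls still deduplicate:
assert cse_lookup("operator.sub", ("getitem_1", 1)) == a
```

Without the type tag, `make_key` would return equal keys for both calls (because `("getitem_1", 1) == ("getitem_1", 1.0)`), reproducing the bug above where `sub_1` collapsed into `sub`.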
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119567
Approved by: https://github.com/eellison