Replace .new_zeros(()) with 0.0 in torch/_decomp/decompositions (#83734)
`new_zeros` is decomposed into `prims.empty_strided`+`prims.fill`+`prims.copy_to` and none of these are supported by prims+nvFuser executor currently.
Replacing it with 0.0 makes these backward decompositions nvFuser friendly.
Example with `torch.ops.aten.hardsigmoid_backward.default`:
```py
# Before this PR
opcode name target args kwargs
------------- ------------------------ -------------------------------- ------------------------------------------------------------ ----------------------------------------------------------------------------------------
placeholder a_1 a_1 () {}
placeholder g_1 g_1 () {}
call_function gt_default nvprims.gt.default (a_1, -3.0) {}
call_function lt_default nvprims.lt.default (a_1, 3.0) {}
call_function bitwise_and_default nvprims.bitwise_and.default (gt_default, lt_default) {}
call_function mul_default nvprims.mul.default (g_1, 0.16666666666666666) {}
call_function empty_strided prims.empty_strided.default ([], []) {'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}
call_function fill_default prims.fill.default (empty_strided, 0) {}
call_function copy_to_default prims.copy_to.default (empty_strided, fill_default) {}
call_function broadcast_in_dim_default nvprims.broadcast_in_dim.default (copy_to_default, [3, 2], []) {}
call_function where_default nvprims.where.default (bitwise_and_default, mul_default, broadcast_in_dim_default) {}
output output output (where_default,) {}
# After this PR
opcode name target args kwargs
------------- ------------------- --------------------------- --------------------------------------- --------
placeholder a_1 a_1 () {}
placeholder g_1 g_1 () {}
call_function gt_default nvprims.gt.default (a_1, -3.0) {}
call_function lt_default nvprims.lt.default (a_1, 3.0) {}
call_function bitwise_and_default nvprims.bitwise_and.default (gt_default, lt_default) {}
call_function mul_default nvprims.mul.default (g_1, 0.16666666666666666) {}
call_function where_default nvprims.where.default (bitwise_and_default, mul_default, 0.0) {}
output output output (where_default,) {}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83734
Approved by: https://github.com/Chillee