pytorch
473b733b - Replace .new_zeros(()) with 0.0 in torch/_decomp/decompositions (#83734)

Commit
2 years ago
Replace .new_zeros(()) with 0.0 in torch/_decomp/decompositions (#83734) `new_zeros` is decomposed into `prims.empty_strided`+`prims.fill`+`prims.copy_to` and none of these are supported by prims+nvFuser executor currently. Replacing it with 0.0 makes these backward decompositions nvFuser friendly. Example with `torch.ops.aten.hardsigmoid_backward.default`: ```py # Before this PR opcode name target args kwargs ------------- ------------------------ -------------------------------- ------------------------------------------------------------ ---------------------------------------------------------------------------------------- placeholder a_1 a_1 () {} placeholder g_1 g_1 () {} call_function gt_default nvprims.gt.default (a_1, -3.0) {} call_function lt_default nvprims.lt.default (a_1, 3.0) {} call_function bitwise_and_default nvprims.bitwise_and.default (gt_default, lt_default) {} call_function mul_default nvprims.mul.default (g_1, 0.16666666666666666) {} call_function empty_strided prims.empty_strided.default ([], []) {'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False} call_function fill_default prims.fill.default (empty_strided, 0) {} call_function copy_to_default prims.copy_to.default (empty_strided, fill_default) {} call_function broadcast_in_dim_default nvprims.broadcast_in_dim.default (copy_to_default, [3, 2], []) {} call_function where_default nvprims.where.default (bitwise_and_default, mul_default, broadcast_in_dim_default) {} output output output (where_default,) {} # After this PR opcode name target args kwargs ------------- ------------------- --------------------------- --------------------------------------- -------- placeholder a_1 a_1 () {} placeholder g_1 g_1 () {} call_function gt_default nvprims.gt.default (a_1, -3.0) {} call_function lt_default nvprims.lt.default (a_1, 3.0) {} call_function bitwise_and_default nvprims.bitwise_and.default (gt_default, lt_default) {} call_function mul_default nvprims.mul.default (g_1, 0.16666666666666666) {} call_function where_default nvprims.where.default (bitwise_and_default, mul_default, 0.0) {} output output output (where_default,) {} Pull Request resolved: https://github.com/pytorch/pytorch/pull/83734 Approved by: https://github.com/Chillee
Author
Committer
Parents
Loading