pytorch
752d496c - Fix `broadcast_in_dim` support in NVFuser Frontend (#76790)

Commit
3 years ago
This PR primarily addresses augmenting the frontend to properly support `broadcast_in_dim`. This required making a new version of `define_tensor()` that takes in the `sizes` and `strides` of input tensors in order to properly determine broadcasts.

This PR also fixes `python_example.py`, which broke when a new argument was added to reductions to allow the user to specify an output data type.

`define_tensor()` Interface Example:
```
fusion2 = Fusion()

input1 = torch.ones(1, 1, 4, device='cuda')
input2 = torch.ones(2, 3, 4, device='cuda')

with FusionDefinition(fusion2) as fd:
    t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
    t1 = fd.define_tensor(sizes=input2.size(), strides=input2.stride())

    fd.add_input(t0)
    fd.add_input(t1)

    t0_b = fd.Ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])
    print("Broadcast TensorView", t0_b)
    t2 = fd.Ops.add(t0_b, t1)

    fd.add_output(t2)
```

Print statement of defined broadcast tensor:
```
Broadcast TensorView T2_l[ sbS6{1}, sbS7{1}, iS8{i2} ]
 DataType: float
 Contiguity: ttt
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76790
Approved by: https://github.com/mruberry, https://github.com/jjsjann123
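To make the `broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])` call above easier to read, here is a minimal pure-Python sketch of the shape bookkeeping such an operation implies. It is not the NVFuser implementation: the helper name `broadcast_in_dim_plan` is hypothetical, and the rules it enforces (one output axis per input axis, strictly increasing axes, each input size either 1 or equal to the mapped output size) are assumed from the XLA-style definition of the operation.

```python
def broadcast_in_dim_plan(in_shape, out_shape, broadcast_dims):
    """Hypothetical helper: plan a broadcast_in_dim-style operation.

    Returns (view_shape, expanded_axes):
      view_shape    - the input shape placed into the output rank,
                      padded with size-1 axes where no input axis maps
      expanded_axes - output axes where a size-1 axis must be stretched
    """
    if len(broadcast_dims) != len(in_shape):
        raise ValueError("need exactly one output axis per input axis")
    if any(ax < 0 or ax >= len(out_shape) for ax in broadcast_dims):
        raise ValueError("broadcast_dims entry is out of range")
    # Assumption: axes must be strictly increasing (XLA-style rule).
    if list(broadcast_dims) != sorted(set(broadcast_dims)):
        raise ValueError("broadcast_dims must be strictly increasing")

    view_shape = [1] * len(out_shape)
    for in_axis, out_axis in enumerate(broadcast_dims):
        size = in_shape[in_axis]
        if size not in (1, out_shape[out_axis]):
            raise ValueError(
                f"input axis {in_axis} has size {size}, which cannot "
                f"broadcast to {out_shape[out_axis]}"
            )
        view_shape[out_axis] = size

    expanded_axes = [
        ax
        for ax, (v, o) in enumerate(zip(view_shape, out_shape))
        if v == 1 and o != 1
    ]
    return view_shape, expanded_axes


# Same shapes as the commit's example: (1, 1, 4) broadcast into (2, 3, 4).
view_shape, expanded_axes = broadcast_in_dim_plan((1, 1, 4), (2, 3, 4), (0, 1, 2))
print(view_shape, expanded_axes)
```

For the shapes in the example, only the first two axes need stretching, which appears consistent with the first two domains of the printed TensorView carrying a `b` marker while the last one does not.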