pytorch
fe1968de - [primTorch] Prototype nvFuser integration and test_prims.py

Commit
4 years ago
[primTorch] Prototype nvFuser integration and test_prims.py This adds prototype nvFuser integration for the following prims: - broadcast_in_dim - convert_element_type - add - div - ge - gt - le - lt - mul Adding it for additional prims supported by nvFuser's prototype Python frontend should be easy. This also adds a new sugar to run operations using the ATen or nvFuser trace executors. For example: ``` def foo(a, b): return torch.add(a, b) traced_foo = make_traced(foo) a = torch.randn((1, 2, 3, 4, 5), device='cuda') b = torch.randn((1, 2, 3, 4, 5), device='cuda') result = traced_foo(a, b, executor='nvfuser') ``` Currently only operations with tensor inputs and one tensor output are supported, and the operation must be composed exclusively of reference or prim operations. Finally, this adds a new test, test_prims.py, that just tests the broadcast_in_dim prim for now. In the future we'll likely have OpInfos for each prim, but we'll need a reference implementation of broadcast_in_dim to make that interesting. Pull Request resolved: https://github.com/pytorch/pytorch/pull/76560 Approved by: https://github.com/ngimel
Author
Mike Ruberry
Committer
Parents
Loading