xla
66ed39ba - [Pallas] Remove torch.empty in tracing (#6897)

Commit
1 year ago
[Pallas] Remove torch.empty in tracing (#6897) Summary: Previously we rely on torch.empty to create some empty tensors as the outputs from the Pallas and then make Pallas as in-place ops. However, it turns out that torch.empty will actually allocate memory and therefore it's expansive to use. In this change, I switched to simply pass the shapes and dtypes to construct the graph node. Test Plan: PJRT_DEVICE=TPU python test/test_pallas.py Performance benchmarks can be found: http://shortn/_wdmom7I6q7
Author
Parents
Loading