[Pallas] Remove torch.empty in tracing (#6897)
Summary:
Previously we rely on torch.empty to create some empty tensors as the outputs from the Pallas and then make Pallas as in-place ops. However, it turns out that torch.empty will actually allocate memory and therefore it's expansive to use. In this change, I switched to simply pass the shapes and dtypes to construct the graph node.
Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py
Performance benchmarks can be found: http://shortn/_wdmom7I6q7