[FSDP] Adopts XLA_DISABLE_FUNCTIONALIZATION (#4806)
Summary:
Functionalization introduces a large memory regression, around 59% for GPT-2 with FSDP. The regression comes from two parts:
1. one is introduced by the functionalization mechanism itself.
2. the other is brought in by the torch_xla._XLAC._replace_xla_tensor() change.
The XLA_DISABLE_FUNCTIONALIZATION flag already works around the 1st part; here we adopt the same flag to work around the 2nd part as well.
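As an illustrative sketch only (the actual flag parsing inside torch_xla may differ), gating a code path on an environment flag like XLA_DISABLE_FUNCTIONALIZATION typically looks like this; `legacy_replace` and `functionalized_replace` are hypothetical placeholders for the two code paths:

```python
import os


def functionalization_disabled() -> bool:
    # Sketch: treat XLA_DISABLE_FUNCTIONALIZATION=1 as "disabled".
    # The real parsing in torch_xla may accept other truthy values.
    return os.environ.get("XLA_DISABLE_FUNCTIONALIZATION", "0") == "1"


def replace_xla_tensor(old, new):
    # Hypothetical dispatch: when functionalization is disabled, take the
    # pre-functionalization in-place replacement path; otherwise use the
    # functionalized path.
    if functionalization_disabled():
        return legacy_replace(old, new)
    return functionalized_replace(old, new)


# Placeholder implementations so the sketch is self-contained.
def legacy_replace(old, new):
    return new


def functionalized_replace(old, new):
    return new
```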
Test Plan:
XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=TPU python test/test_train_mp_mnist_fsdp_with_ckpt.py --batch_size 16 --drop_last --num_epochs 2 --use_nested_fsdp --metrics_debug