xla
cb02d52b - [FSDP] Adopts XLA_DISABLE_FUNCTIONALIZATION (#4806)

Summary:
Functionalization introduces a large memory regression, around 59% for GPT-2 with FSDP. The regression comes from two parts:
1. one is introduced by the functionalization mechanism itself;
2. the other is brought in by the torch_xla._XLAC._replace_xla_tensor() change.

The XLA_DISABLE_FUNCTIONALIZATION flag already works around the 1st part; here we adopt the same flag to work around the 2nd part as well.

Test Plan:
XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=TPU python test/test_train_mp_mnist_fsdp_with_ckpt.py --batch_size 16 --drop_last --num_epochs 2 --use_nested_fsdp --metrics_debug
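For context, a minimal sketch of how such a flag can gate the tensor-replacement path. This is not the actual FSDP change: the helper name `_update_param` and the simplified environment check are illustrative only, and the exact semantics of `torch_xla._XLAC._replace_xla_tensor()` should be confirmed against the torch_xla source.

```python
import os

import torch
import torch_xla

# Simplified flag check; the real torch_xla flag parsing may accept more
# spellings than a literal "1".
_FUNCTIONALIZATION_DISABLED = os.environ.get(
    "XLA_DISABLE_FUNCTIONALIZATION", "0") == "1"


def _update_param(param: torch.Tensor, new_data: torch.Tensor) -> None:
    """Hypothetical helper showing how the flag could gate the update path."""
    if _FUNCTIONALIZATION_DISABLED:
        # With functionalization disabled, fall back to a plain .data swap,
        # avoiding the extra memory attributed to _replace_xla_tensor().
        param.data = new_data
    else:
        # Functionalized path: replace the underlying XLA tensor in place.
        torch_xla._XLAC._replace_xla_tensor(param, new_data)
```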