[FSDP] Replace all .data usages with modern approaches (#4485)
Summary:
The .data API is deprecated by PyTorch. Let's replace all .data usages with modern approaches.
1. we replace .data = with .set_. We also lower .set_ within this PR.
2. we then remove other .data given there is no needs to detach.
There is one caveat as we now requires torch.no_grad() to run inferences. Otherwise, it will introduce weird error. Need to follow up later.
Test Plan:
1. PJRT_DEVICE=TPU python test/test_train_mp_mnist_fsdp_with_ckpt.py --batch_size 16 --drop_last --num_epochs 2 --use_nested_fsdp
Expected Max Accuracy: 98.88%+
2. python test/test_operations.py -v -k test_set