xla
6f0b61e5 - [FSDPv2] Support MultiSlice (#7044)

Commit
1 year ago
[FSDPv2] Support MultiSlice (#7044) Summary: This pull request adds the multi-slice support for FSDPv2. Basically, the default setup is to use the dcn axis as the data axis, and it means we only do data parallel over multi-slices. In the future, we could also support FSDP over mutli-slices. Test Plan: PJRT_DEVICE=TPU python test/spmd/test_fsdp_v2.py
Author
Parents
Loading