xla
3f47abab - use XLA patched linear in FSDP (fix #3811 and #3718) and expose options on param sharding dim and pinning memory (#3830)

* use XLA patched linear in FSDP; expose options for padding in all_gather and pinning layout
* directly patch `nn.Linear`'s forward method instead of `torch.nn.functional.linear` to be thread-safe in PJRT
* introduce `shard_param_on_dim_0` to handle potential compiler issues
* check `full_param.dim() == 1` instead of catching exceptions in `_consolidate_param` in state_dict_utils.py
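A minimal sketch of the per-instance patching idea in the second bullet, assuming a hypothetical `xla_friendly_linear` stand-in for the XLA-patched linear op (the repository's actual helper name may differ): rebinding each `nn.Linear` instance's `forward` keeps the change local to the wrapped module, whereas monkey-patching the global `torch.nn.functional.linear` would be visible to every thread PJRT spawns per device.

```python
import types

import torch
import torch.nn as nn


def xla_friendly_linear(x, weight, bias=None):
    # Hypothetical stand-in for the XLA-patched linear op; written as a plain
    # einsum so the sketch stays self-contained and runnable on CPU.
    out = torch.einsum("...i,oi->...o", x, weight)
    if bias is not None:
        out = out + bias
    return out


def patch_linear_forwards(module: nn.Module) -> nn.Module:
    """Rebind forward on every nn.Linear submodule (illustrative only)."""
    for sub in module.modules():
        if isinstance(sub, nn.Linear):
            def patched_forward(self, x):
                return xla_friendly_linear(x, self.weight, self.bias)
            # Bind to this instance only, leaving torch.nn.functional.linear
            # and other Linear instances untouched (the thread-safety point).
            sub.forward = types.MethodType(patched_forward, sub)
    return module
```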
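And a hedged usage sketch of the options mentioned in the first and third bullets, assuming the wrapper is `XlaFullyShardedDataParallel` from `torch_xla.distributed.fsdp`; `shard_param_on_dim_0` is named in the commit itself, while the keyword used here for layout pinning is an assumption, so check the class docstring for the real name.

```python
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()
model = nn.Sequential(
    nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)
).to(device)

# shard_param_on_dim_0: shard each parameter only along its first dimension
# (instead of flattening) to sidestep potential compiler issues.
# pin_layout_in_collective_ops: assumed keyword for the "pin layout/memory"
# option exposed by this commit; the actual name may differ.
fsdp_model = FSDP(
    model,
    shard_param_on_dim_0=True,
    pin_layout_in_collective_ops=True,
)
```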