pytorch
31351c61 - [FSDP] Tighten post-bwd cast to `reduce_dtype` (#90615)

Commit
2 years ago
[FSDP] Tighten post-bwd cast to `reduce_dtype` (#90615) This lowers the `reduce_dtype` retrieval to the `handle` instead of the `state` in preparation for `fully_shard`, and this adds a guard to avoid a no-op `to()` call. Note that this change pretty much gets overridden in following PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90615 Approved by: https://github.com/rohan-varma
Author
Committer
Parents
Loading