[FSDP][1/N] Rework `clip_grad_norm_()` and tests (#87479)
This PR reworks FSDP's `clip_grad_norm_()` and its unit tests. The unit tests in `test_fsdp_core.py` still need to be revisited; that will be done in follow-up work.
Some details in arbitrary order:
- This renames `_calc_grad_norm()` to `_get_grad_norm()` to simplify our verb usage in method names. Otherwise, we may diverge into different verbs like "compute", "calculate", "get", "find", etc. I am open to discussion here.
- Because the underlying norm computation calls `torch.linalg.vector_norm()`, which accepts infinity as the norm type, there is no reason to keep a separate conditional branch for the infinity norm (see the first sketch after this list).
- This removes a host-device synchronization point from `clip_grad_norm_()` by using the same trick as `torch.nn.utils.clip_grad_norm_()` (see the second sketch after this list). This may improve throughput for workloads like metaseq, which computes gradient norms regularly.
- This makes `clip_grad_norm_()` return the total norm, as stated in the docstring. Previously, nothing was returned.
- This rewrites the unit tests, which were slightly problematic. Much of the logic to verify that gradient norms were computed correctly was exactly the same as the logic used to compute them in FSDP (i.e. `^p`, sum via all-reduce, `^(1/p)`), which defeats the purpose of unit testing (the third sketch after this list shows one way to avoid this). There were some other oddities like `input = torch.rand(14, 2, device=self.rank); in_data = torch.tensor(input[self.rank], device=self.rank)`, where we materialize a full `(14, 2)` shape but only ever use the first two rows (assuming world size 2).
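For reference, a minimal example (not FSDP's actual code) showing that `torch.linalg.vector_norm()` handles the infinity norm directly, so a dedicated branch is unnecessary:

```python
import torch

grads = [torch.randn(4), torch.randn(3)]
norm_type = float("inf")

# The same call covers finite and infinite norm types.
per_grad_norms = torch.stack(
    [torch.linalg.vector_norm(g, norm_type) for g in grads]
)
total_norm = torch.linalg.vector_norm(per_grad_norms, norm_type)

# For the infinity norm, this equals the max absolute value across all gradients.
assert total_norm == max(g.abs().max() for g in grads)
```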
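The synchronization-avoiding trick is, roughly, to clamp the clip coefficient on device and multiply unconditionally, instead of branching on `clip_coef < 1`, which would require reading the scalar on the host. A minimal sketch of the idea (mirroring `torch.nn.utils.clip_grad_norm_()`, not FSDP's exact implementation):

```python
import torch

def clip_grads_(grads, max_norm, total_norm):
    """Scale ``grads`` in place so that their total norm is at most ``max_norm``.

    ``total_norm`` is a 0-dim tensor on the device; we never call ``.item()``
    on it, so there is no host-device synchronization.
    """
    clip_coef = max_norm / (total_norm + 1e-6)
    # Clamping to 1 and multiplying unconditionally keeps everything on the
    # device; multiplying by 1 is a no-op when no clipping is needed.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device))
    return total_norm
```

With the return value in place, callers can do e.g. `total_norm = fsdp_model.clip_grad_norm_(max_norm=1.0)` and log the norm directly.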
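One way to verify the norms without re-implementing FSDP's sharded math is to use the stock PyTorch utility on an unsharded replica as the oracle. A hypothetical test helper (the helper name and setup are illustrative, not necessarily what the rewritten tests do):

```python
import torch
import torch.nn as nn

def check_clip_grad_norm(fsdp_model, local_model: nn.Module, max_norm: float, norm_type: float):
    """Compare FSDP's result against ``torch.nn.utils.clip_grad_norm_()`` run on
    an unsharded replica, assuming ``local_model`` already holds the same
    unsharded gradients as ``fsdp_model``.
    """
    ref_total_norm = torch.nn.utils.clip_grad_norm_(
        local_model.parameters(), max_norm, norm_type
    )
    fsdp_total_norm = fsdp_model.clip_grad_norm_(max_norm, norm_type)
    torch.testing.assert_close(fsdp_total_norm, ref_total_norm)
```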
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87479
Approved by: https://github.com/rohan-varma