pytorch
04bd9d75 - [DDP] Add API to get model parameters in hook (#61637)

Commit
3 years ago
[DDP] Add API to get model parameters in hook (#61637) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61637 To support running optimizer as a communication hook, add API to retrieve the model parameters. The API returns a `dict[idx -> tensor]` where `idx` is the intra bucket index of gradient tensor and thus the same index of `perParameterTensors`. The API can be used as follows to retrieve the model parameters: ``` per_param_grad_tensors = bucket.get_per_parameter_tensors() idx_to_model_params = bucket.get_grad_index_to_variable_mapping() for grad_tensor_idx, model_param in idx_to_model_params.items(): self.assertEqual(model_param.grad, per_param_grad_tensors[grad_tensor_idx]) ``` This provides a way for comm. hook developer to retrieve model parameters within a hook. In the next diffs, we will use this to run optimizer as a DDP comm. hook. ghstack-source-id: 133768666 Test Plan: CI Reviewed By: SciPioneer Differential Revision: D29691418 fbshipit-source-id: 4bfa824768a5850f73ee330017e2bcc29ceb7edc
Author
Parents
Loading