c46af25b - Initialize optimizer in dynamo to avoid graph break and tracing slowness (#102640)

On calls to `_init_group`, rather than tracing through it, extract the Python values from the arguments and call the initialization directly. This avoids tracing the function, which is very slow with large numbers of parameters, and also avoids graph breaking on it. This is sound here because, in the eager case, the state is only initialized once. Guards on the state and params are generated explicitly rather than by tracing the initialization.

Caveats: `_init_group` also gathers various state tensors into lists, via mutating list arguments, to pass to the functional optimizer implementation. These state tensors exist on the optimizer itself, but we don't know exactly how the gathering is done or which tensors correspond to which attributes of the optimizer module (each optimizer has different state). To handle this, we keep weak pointers to all of the tensors collected into the lists in globals (similar to how parameter keys are stored for dictionaries). These pointers are guaranteed to stay alive as long as the optimizer object is alive, provided its internal state is not interfered with, and they are guarded with weakref guards.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102640
Approved by: https://github.com/jansel