pytorch
2491aa53 - Make DataParallel generic (#102455)

Commit
1 year ago
Make DataParallel generic (#102455) Fixes #102441 improves type hinting of the module attribute, since it can easily be bound in `DataParallel.__init__` ```python from torch.nn import DataParallel class MyModule(Module): ... my_data_parallel = DataParallel(MyModule(), device_ids=[0, 1, 2]) reveal_type(my_data_parallel) # Type of "my_data_parallel" is "DataParallel[MyModule]" reveal_type(my_data_parallel.module) # Type of "my_data_parallel.module" is "MyModule" ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/102455 Approved by: https://github.com/Skylion007
Author
Committer
Parents
Loading