pytorch
a0ce8da2 - Fix DistributedSampler mem usage on large datasets (#51841)

Commit
3 years ago
Fix DistributedSampler mem usage on large datasets (#51841) Summary: The current implementation of DistributedSampler generates a python list to hold all of the indices, and then returns a slice of this list for the given rank (creating a partial copy of the list). When the underlying dataset is large, both of these choices waste a large amount of memory. It is much more efficient to create a tensor to hold the indices, and then index into that tensor instead of creating slices. In the case of a sampler with `shuffle=False`, it would be possible to avoid creating the `indices` tensor entirely (since the index will always match the value), but I have opted instead here to keep the implementation as similar to the existing version as possible. One possible benefit of this approach is that memory usage will not significantly change based on changing this parameter. Still, it might be better to simply return the indices directly without the underlying array. Additionally, the logic around calculating the number of samples is unnecessarily complex. When dropping the last batch, this can be a simple floor division. In a simple test script which creates a sampler for a dataset with a 100,000,000 items, memory usage is reduced 98% compared to the existing implementation. Fixes https://github.com/pytorch/pytorch/issues/45427 Pull Request resolved: https://github.com/pytorch/pytorch/pull/51841 Reviewed By: albanD Differential Revision: D28240105 Pulled By: rohan-varma fbshipit-source-id: 4c6aa493d0f75c07ec14c98791b3a531300fb1db
Author
Parents
Loading