Fix DistributedSampler mem usage on large datasets (#51841)
Summary:
The current implementation of DistributedSampler generates a Python list holding all of the indices, and then returns a slice of this list for the given rank (creating a partial copy of the list). When the underlying dataset is large, both of these steps waste a large amount of memory. It is much more efficient to store the indices in a tensor and index into that tensor instead of creating slices.
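As a rough illustration of the difference (a minimal sketch with made-up sizes, not the exact code in this patch):

```python
import torch

# Illustrative sizes only.
dataset_len, num_replicas, rank = 1_000_000, 8, 0

# List-based approach (sketch of the old behavior): every index becomes a
# Python int object, and the per-rank slice is a second, partial copy.
indices_list = list(range(dataset_len))
subset_list = indices_list[rank:dataset_len:num_replicas]

# Tensor-based approach (sketch of the idea here): one contiguous int64
# buffer, and the per-rank subset is obtained by indexing into it instead
# of materializing another Python list.
indices = torch.randperm(dataset_len)             # shuffle=True case
subset = indices[rank:dataset_len:num_replicas]   # strided view, no element copy
```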
In the case of a sampler with `shuffle=False`, it would be possible to avoid creating the `indices` tensor entirely (since the index always matches the value), but I have opted here to keep the implementation as similar to the existing version as possible. One benefit of this approach is that memory usage does not change significantly with the value of this parameter. Still, it might be better to simply return the indices directly without the backing tensor, as sketched below.
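For `shuffle=False`, that optimization would look roughly like this (a sketch only; the real sampler has additional padding and `drop_last` bookkeeping):

```python
import torch

dataset_len, num_replicas, rank = 1_000_000, 8, 0

# With shuffle=False the stored tensor is just the identity mapping ...
indices = torch.arange(dataset_len)                 # indices[i] == i for every i
per_rank = indices[rank:dataset_len:num_replicas]

# ... so the same per-rank order could be produced without any tensor:
direct = range(rank, dataset_len, num_replicas)

assert per_rank[:5].tolist() == list(direct)[:5]
```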
Additionally, the logic around calculating the number of samples is unnecessarily complex. When `drop_last` is set, this reduces to a simple floor division.
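To make the simplification concrete, here is a quick equivalence check. It assumes the pre-patch `drop_last` formula is the ceil-based expression found in recent versions of the sampler; if the exact expression differs, the point stands that floor division gives the same count:

```python
import math

# For every dataset length and replica count, the ceil-based expression
# (assumed pre-patch formula) matches plain floor division under drop_last.
for dataset_len in range(1, 1001):
    for num_replicas in range(1, 17):
        if dataset_len % num_replicas != 0:
            old = math.ceil((dataset_len - num_replicas) / num_replicas)
        else:
            old = math.ceil(dataset_len / num_replicas)
        assert old == dataset_len // num_replicas
```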
In a simple test script that creates a sampler for a dataset with 100,000,000 items, memory usage is reduced by 98% compared to the existing implementation.
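That script is not included here, but a hypothetical way to measure the effect is to allocate the indices one way or the other and compare peak RSS across two runs (Unix-only; `ru_maxrss` is reported in kilobytes on Linux and bytes on macOS):

```python
import resource
import sys

import torch

N = 100_000_000
mode = sys.argv[1] if len(sys.argv) > 1 else "tensor"

if mode == "list":
    indices = list(range(N))          # old-style Python list of indices
    subset = indices[0:N:8]           # partial copy for one of 8 ranks
else:
    indices = torch.randperm(N)       # tensor of indices
    subset = indices[0:N:8]           # strided view, no copy

peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print(f"{mode}: peak RSS = {peak} (kB on Linux, bytes on macOS)")
```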
Fixes https://github.com/pytorch/pytorch/issues/45427
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51841
Reviewed By: albanD
Differential Revision: D28240105
Pulled By: rohan-varma
fbshipit-source-id: 4c6aa493d0f75c07ec14c98791b3a531300fb1db