Do not materialize entire randperm in RandomSampler (#103339)
In our DDP training workloads, each rank was initializing a `RandomSampler` over a dataset of 3.5 billion items. We noticed that while this sampler was in scope, `gc.collect()` calls took on the order of seconds to run, which would slow down the entire training iteration. The cause is that `torch.randperm(n).tolist()` creates a Python list of 3.5 billion elements, and the periodic mark-and-sweep garbage collector must traverse every element of that list on each full collection.
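As a standalone illustration of the effect (this snippet is not from the PR; the sizes are chosen so it runs in reasonable time, the timings are machine-dependent, and the largest list needs a few GB of RAM), collection time grows with the length of a tracked list because the collector visits every element when it traverses the container:

```python
import gc
import time

# Illustration only: a large Python list makes every full gc.collect()
# traverse all of its elements, so collection time scales with length.
for n in (1_000_000, 10_000_000, 100_000_000):
    xs = list(range(n))
    start = time.perf_counter()
    gc.collect()
    print(f"n={n:>11,}: gc.collect() took {time.perf_counter() - start:.3f}s")
    del xs
```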
This PR replaces the `.tolist()` call with `.numpy()` and calls `.item()` on each element as it is requested (see the sketch after the list below). This has two benefits:
1. The first call to `RandomSampler.__next__` should be about twice as fast, since `.numpy()` does not copy the contents of the original tensor.
2. The runtime of `gc.collect()` calls no longer scales linearly with the size of the dataset passed to `RandomSampler`.
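As a rough sketch of the idea (not the exact PyTorch source; `RandomSamplerSketch` is a hypothetical stand-in for `torch.utils.data.RandomSampler`), the permutation stays in a numpy array, which shares memory with the tensor, and each index is converted to a Python int only on demand:

```python
import torch

class RandomSamplerSketch:
    """Hypothetical stand-in for torch.utils.data.RandomSampler."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        n = len(self.data_source)
        # Before: torch.randperm(n).tolist() built a list of n Python ints,
        # all of which the cyclic GC traverses on every full collection.
        # After: .numpy() is a zero-copy view of the tensor's storage, and
        # .item() converts one element at a time, only when requested.
        perm = torch.randperm(n).numpy()
        for idx in perm:
            yield idx.item()

    def __len__(self):
        return len(self.data_source)
```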
I've attached some `timeit` samples to illustrate the speedup from this PR:
```
Main (no GC):       51.72115747816861
Main (10 GC calls): 83.61965207383037
PR (no GC):         33.06403830461204
PR (10 GC calls):   33.959467427805066
```
Benchmark code:
```python
from timeit import timeit

# Baseline: materialize the full permutation as a Python list.
baseline_no_gc = """
import torch

n = int(1e9)
steps = n // 100

x = torch.randperm(n).tolist()
x_iter = iter(x)
for i in range(steps):
    next(x_iter)
"""

baseline_gc = """
import torch
import gc

n = int(1e9)
steps = n // 100
gc_every = steps // 10

x = torch.randperm(n).tolist()
x_iter = iter(x)
for i in range(steps):
    next(x_iter)
    if i % gc_every == 0:
        gc.collect()
"""

# PR behavior: keep the permutation as a numpy array and convert each
# element to a Python int lazily.
numpy_no_gc = """
import torch

n = int(1e9)
steps = n // 100

x = torch.randperm(n).numpy()
x_iter = (i.item() for i in x)
for i in range(steps):
    next(x_iter)
"""

numpy_gc = """
import torch
import gc

n = int(1e9)
steps = n // 100
gc_every = steps // 10

x = torch.randperm(n).numpy()
x_iter = (i.item() for i in x)
for i in range(steps):
    next(x_iter)
    if i % gc_every == 0:
        gc.collect()
"""

if __name__ == "__main__":
    print("Main (no GC):", timeit(baseline_no_gc, number=1))
    print("Main (10 GC calls):", timeit(baseline_gc, number=1))
    print("PR (no GC):", timeit(numpy_no_gc, number=1))
    print("PR (10 GC calls):", timeit(numpy_gc, number=1))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103339
Approved by: https://github.com/kit1980