pytorch
d80174e2 - Do not materialize entire randperm in RandomSampler (#103339)

Commit
1 year ago
Do not materialize entire randperm in RandomSampler (#103339) In our DDP training workloads, each rank was initializing a `RandomSampler` for a dataset with a length of 3.5 billion items. We noticed that when this sampler was in scope, `gc.collect` calls were taking on the order of seconds to run, which would slow down the entire training iteration. This is because when we call `torch.randperm(n).tolist()`, we create a python list of 3.5 billion items, which massively slows down the periodic mark & sweep garbage collection. This PR swaps out the `.tolist()` call with a `.numpy()` call and manually calls `.item()` on each element as it is being requested. This has two benefits: 1. The first call to `RandomSampler::__next__` should be about twice as fast, since `.numpy` does not copy the contents of the original tensor 2. The runtime of `gc.collect()` calls no longer scales linearly with the size of the dataset passed to `RandomSampler` I've attached some `timeit` samples to illustrate the speedups with this Pr: ``` Main (no GC): 51.72115747816861 Main (10 GC calls) 83.61965207383037 PR (no GC) 33.06403830461204 PR (10 GC calls) 33.959467427805066 ``` Code ```python from timeit import timeit baseline_no_gc = """ import torch n = int(1e9) steps = n // 100 x = torch.randperm(n).tolist() x_iter = iter(x) for i in range(steps): next(x_iter) """ baseline_gc = """ import torch import gc n = int(1e9) steps = n // 100 gc_every = steps // 10 x = torch.randperm(n).tolist() x_iter = iter(x) for i in range(steps): next(x_iter) if i % gc_every == 0: gc.collect() """ numpy_no_gc = """ import torch n = int(1e9) steps = n // 100 x = torch.randperm(n).numpy() x_iter = (i.item() for i in x) for i in range(steps): next(x_iter) """ numpy_gc = """ import torch import gc n = int(1e9) steps = n // 100 gc_every = steps // 10 x = torch.randperm(n).numpy() x_iter = (i.item() for i in x) for i in range(steps): next(x_iter) if i % gc_every == 0: gc.collect() """ if __name__ == "__main__": print("Main (no GC): ", timeit(baseline_no_gc, number=1)) print("Main (10 GC calls)", timeit(baseline_gc, number=1)) print("PR (no GC)", timeit(numpy_no_gc, number=1)) print("PR (10 GC calls)", timeit(numpy_gc, number=1)) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/103339 Approved by: https://github.com/kit1980
Author
Committer
Parents
Loading