pytorch
9f1468ae - CyclicLR memory leak fix (#85462)

Commit
3 years ago
CyclicLR memory leak fix (#85462) Hi, we noticed in our team that by using CyclicLR, there is a problem with memory clearance on GPU (probably it will be the case without the GPU as well, but that was our use case) After initializing CyclicLR, GPU memory is not cleared even after the model, optimizer and scheduler are out of scope (e.g. reference count is zero). This is because `__init__` method inside `CyclicLR` creates reference to its own methods and it will not get removed until `gc.collect()` is called manually. This is a problem if people want to test multiple models in one run of a script, after testing the first model, second one will fail on `CUDA out of memory error` because the first one is not cleared from the memory. I propose a simple fix by using `weakref`, similarly as in `_LRScheduler` base class, but if you have any comments I am happy to change it. Here is the code to reproduce the bug: ``` import torch import weakref from transformers import DetrForObjectDetection class X: def __init__(self, optimizer): self.optimizer = optimizer # Will cause cyclic reference. self.func = self.dummy # Will work as expected, memory cleared after instance count is zero. # self.func = weakref.WeakMethod(self.dummy) def dummy(self, x): return 1. def test(): model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50') model.to('cuda') optimizer = torch.optim.Adam(model.parameters()) x = X(optimizer) test() print(f'{torch.cuda.memory_reserved()}, {torch.cuda.memory_allocated()}') # Should print (<some memory>, 0), but with cyclic reference, it will print (<some memory>, <some memory>). ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/85462 Approved by: https://github.com/albanD
Author
Committer
Parents
Loading