pytorch
6fa2d41d - Add mmap option to `torch.load` (#102549)

Add mmap option to `torch.load` (#102549)

Using [`nanoGPT/model.py`](https://github.com/karpathy/nanoGPT/blob/master/model.py), run

<details><summary><b>Click for script to save gpt2-xlarge (1.5B params)</b></summary>

```python
# test_load_save_gpt.py
from model import GPT
import torch
import time

torch.manual_seed(5)

# gpt2-xlarge 1558M parameters
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 48
    n_head: int = 25
    n_embd: int = 1600
    dropout: float = 0.0
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

def f():
    model = GPT(GPTConfig())
    state_dict = model.state_dict()
    start_saving = time.time()
    torch.save(state_dict, "gpt2-xlarge.pth")
    end_saving = time.time()

if __name__ == "__main__":
    f()
```

</details>

<details><summary><b>Click for script to load</b></summary>

```python
# test_load_gpt.py
import torch
from model import GPT
from test_load_save_gpt import GPTConfig
import time
import argparse

def f(mmap, meta):
    device = 'meta' if meta else 'cpu'
    assign = True if meta else False
    with torch.device(device):
        model = GPT(GPTConfig())
    start_loading = time.time()
    loaded_state_dict = torch.load("gpt2-xlarge.pth", _mmap=mmap)
    end_loading = time.time()
    print(f"loading time using torch.load with mmap={mmap}: ", end_loading - start_loading)
    model.load_state_dict(loaded_state_dict, assign=assign)
    end_load_state_dict = time.time()
    print("load_state_dict time: ", end_load_state_dict - end_loading)
    model.cuda()
    end_cuda = time.time()
    print("cuda time using torch.load with mmap: ", end_cuda - end_load_state_dict)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='load_gpt_xlarge')
    parser.add_argument('-m', '--mmap', action='store_true')
    parser.add_argument('-d', '--devicemeta', action='store_true')
    args = parser.parse_args()
    mmap = args.mmap
    meta = args.devicemeta
    f(mmap, meta)
```

</details>

`python test_load_gpt.py`
<img width="614" alt="Screenshot 2023-06-06 at 1 35 43 PM" src="https://github.com/pytorch/pytorch/assets/35276741/ee06e5b3-b610-463b-a867-df995d21af29">

`python test_load_gpt.py --mmap`

<img width="622" alt="Screenshot 2023-06-06 at 1 35 30 PM" src="https://github.com/pytorch/pytorch/assets/35276741/00d2fdd0-b1f5-4313-83dc-e540b654b2af">

If we further use the `with torch.device('meta')` context manager and pull in the changes from https://github.com/pytorch/pytorch/pull/102212, which allow the model to reuse the tensors from the state_dict, we have

`python test_load_gpt.py --mmap --devicemeta`

<img width="727" alt="Screenshot 2023-06-06 at 1 35 51 PM" src="https://github.com/pytorch/pytorch/assets/35276741/b50257d9-092a-49c3-acae-876ee44d009f">

Running the above in a docker container containing a build of PyTorch, with RAM limited to 512mb, by

1) running `make -f docker.Makefile` from the `pytorch/` directory
2) `docker run -m 512m -it <image> bash`
3) `docker cp`-ing `gpt2-xlarge.pth` and `test_load_gpt.py` into the container

`python test_load_gpt.py`

Docker will kill the process due to OOM, whereas

`python test_load_gpt.py --mmap --devicemeta`

<img width="635" alt="Screenshot 2023-06-06 at 1 55 48 PM" src="https://github.com/pytorch/pytorch/assets/35276741/f3820d9e-f24c-43e7-885b-3bfdf24ef8ad">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102549
Approved by: https://github.com/albanD
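The memory behavior above comes from the OS-level `mmap` mechanism this option builds on: mapping a file into virtual memory lets the kernel page bytes in lazily on access, so the whole checkpoint never has to be copied into the process's resident memory at once. A minimal stdlib-only sketch of that mechanism (the file name `demo.bin` is hypothetical, not from this PR):

```python
# Sketch of lazy, page-backed file access via mmap. Only the pages actually
# touched are brought into resident memory; the rest stay on disk.
import mmap
import os

path = "demo.bin"  # hypothetical scratch file for the demo
with open(path, "wb") as f:
    f.write(b"\x00" * (1 << 20))  # write 1 MiB of zeros

with open(path, "rb") as f:
    # length=0 maps the entire file, read-only
    mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
    # Slicing touches only the pages backing these 16 bytes
    chunk = mm[:16]
    total = len(mm)
    mm.close()

os.remove(path)
print(total, chunk == b"\x00" * 16)  # 1048576 True
```

Note that in later PyTorch releases this option is exposed as the public keyword `torch.load(..., mmap=True)` rather than the private `_mmap` spelling used in the scripts above.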