Add FakeTensor support to torch._utils._rebuild_tensor (#108186)
There are two scenarios:
* Scenario 1: the checkpoint was saved with PyTorch < 1.6 (legacy, non-zip format), so loading goes through `_legacy_load`
* Scenario 2: the checkpoint was saved with PyTorch >= 1.6 (zip-based format), so loading goes through `_load`

Both paths end up in `torch._utils._rebuild_tensor`, which fails under an active `FakeTensorMode`; the snippet after this list shows one way to tell the two formats apart.
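A quick illustrative check for which scenario a given checkpoint falls under (the file name below is a placeholder; checkpoints saved with PyTorch >= 1.6 are ordinary zip archives):
```python
import zipfile

# Checkpoints written by torch.save on PyTorch >= 1.6 use a zip-based
# container; older checkpoints are a raw pickle stream handled by
# torch.serialization._legacy_load.
path = "pytorch_model.bin"  # placeholder: any saved checkpoint
fmt = "zip-based (>= 1.6)" if zipfile.is_zipfile(path) else "legacy (< 1.6)"
print(fmt)
```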
Repro Scenario 1:
```python
from torch._subclasses import fake_tensor
import transformers
fake_mode = fake_tensor.FakeTensorMode()
with fake_mode:
    fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")
```
Error:
```bash
Some weights of the model checkpoint at sshleifer/tiny-gpt2 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:463 in │
│ load_state_dict │
│ │
│ 460 │ │ │ ) │
│ 461 │ │ return safe_load_file(checkpoint_file) │
│ 462 │ try: │
│ ❱ 463 │ │ return torch.load(checkpoint_file, map_location="cpu") │
│ 464 │ except Exception as e: │
│ 465 │ │ try: │
│ 466 │ │ │ with open(checkpoint_file) as f: │
│ │
│ /opt/pytorch/torch/serialization.py:1030 in load │
│ │
│ 1027 │ │ │ │ return _legacy_load(opened_file, map_location, _weights_only_unpickler, │
│ 1028 │ │ │ except RuntimeError as e: │
│ 1029 │ │ │ │ raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None │
│ ❱ 1030 │ │ return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args │
│ 1031 │
│ 1032 │
│ 1033 # Register pickling support for layout instances such as │
│ │
│ /opt/pytorch/torch/serialization.py:1258 in _legacy_load │
│ │
│ 1255 │ _sys_info = pickle_module.load(f, **pickle_load_args) │
│ 1256 │ unpickler = UnpicklerWrapper(f, **pickle_load_args) │
│ 1257 │ unpickler.persistent_load = persistent_load │
│ ❱ 1258 │ result = unpickler.load() │
│ 1259 │ │
│ 1260 │ deserialized_storage_keys = pickle_module.load(f, **pickle_load_args) │
│ 1261 │
│ │
│ /opt/pytorch/torch/_utils.py:201 in _rebuild_tensor_v2 │
│ │
│ 198 def _rebuild_tensor_v2( │
│ 199 │ storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None │
│ 200 ): │
│ ❱ 201 │ tensor = _rebuild_tensor(storage, storage_offset, size, stride) │
│ 202 │ tensor.requires_grad = requires_grad │
│ 203 │ if metadata: │
│ 204 │ │ set_tensor_metadata(tensor, metadata) │
│ │
│ /opt/pytorch/torch/_utils.py:180 in _rebuild_tensor │
│ │
│ 177 def _rebuild_tensor(storage, storage_offset, size, stride): │
│ 178 │ # first construct a tensor with the correct dtype/device │
│ 179 │ t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_storage.device) │
│ ❱ 180 │ return t.set_(storage._untyped_storage, storage_offset, size, stride) │
│ 181 │
│ 182 │
│ 183 def get_tensor_metadata(tensor): │
│ │
│ /opt/pytorch/torch/utils/_stats.py:20 in wrapper │
│ │
│ 17 │ │ if fn.__qualname__ not in simple_call_counter: │
│ 18 │ │ │ simple_call_counter[fn.__qualname__] = 0 │
│ 19 │ │ simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1 │
│ ❱ 20 │ │ return fn(*args, **kwargs) │
│ 21 │ return wrapper │
│ 22 │
│ │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1160 in __torch_dispatch__ │
│ │
│ 1157 │ def __torch_dispatch__(self, func, types, args=(), kwargs=None): │
│ 1158 │ │ assert self not in _get_current_dispatch_mode_stack(), func │
│ 1159 │ │ try: │
│ ❱ 1160 │ │ │ return self.dispatch(func, types, args, kwargs) │
│ 1161 │ │ except TypeError: │
│ 1162 │ │ │ log.exception("fake tensor raised TypeError") │
│ 1163 │ │ │ raise │
│ │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1318 in dispatch │
│ │
│ 1315 │ │ │
│ 1316 │ │ # we are falling through to running non constant tensors, any input constant tha │
│ 1317 │ │ # is written to must be invalidated │
│ ❱ 1318 │ │ self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs) │
│ 1319 │ │ │
│ 1320 │ │ # Try for fastpath │
│ 1321 │ │ if has_symbolic_sizes: │
│ │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1557 in invalidate_written_to_constants │
│ │
│ 1554 │ │ any_constant = any(e.constant is not None for e in flat_arg_fake_tensors) │
│ 1555 │ │ if any_constant and get_schema_info(func).is_mutable(): │
│ 1556 │ │ │ schema_info = get_schema_info(func) │
│ ❱ 1557 │ │ │ _, new_kwargs = normalize_function( │
│ 1558 │ │ │ │ func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True │
│ 1559 │ │ │ ) │
│ 1560 │ │ │ for k, v in new_kwargs.items(): │
│ │
│ /opt/pytorch/torch/fx/operator_schemas.py:297 in normalize_function │
│ │
│ 294 │ │ new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, │
│ 295 │ else: │
│ 296 │ │ assert callable(target) │
│ ❱ 297 │ │ torch_op_schemas = get_signature_for_torch_op(target) │
│ 298 │ │ matched_schemas = [] │
│ 299 │ │ if torch_op_schemas: │
│ 300 │ │ │ # Iterate through all of the schema until we find one that matches │
│ │
│ /opt/pytorch/torch/fx/operator_schemas.py:167 in get_signature_for_torch_op │
│ │
│ 164 │ │ │ return (None, None) if return_schemas else None │
│ 165 │ │ schemas = torch._C._jit_get_schemas_for_operator(aten_fn) │
│ 166 │ │
│ ❱ 167 │ signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] │
│ 168 │ return (signatures, schemas) if return_schemas else signatures │
│ 169 │
│ 170 @compatibility(is_backward_compatible=False) │
│ │
│ /opt/pytorch/torch/fx/operator_schemas.py:167 in <listcomp> │
│ │
│ 164 │ │ │ return (None, None) if return_schemas else None │
│ 165 │ │ schemas = torch._C._jit_get_schemas_for_operator(aten_fn) │
│ 166 │ │
│ ❱ 167 │ signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] │
│ 168 │ return (signatures, schemas) if return_schemas else signatures │
│ 169 │
│ 170 @compatibility(is_backward_compatible=False) │
│ │
│ /opt/pytorch/torch/fx/operator_schemas.py:70 in _torchscript_schema_to_signature │
│ │
│ 67 │ from inspect import Parameter │
│ 68 │ parameters : List[Parameter] = [] │
│ 69 │ for arg in ts_schema.arguments: │
│ ❱ 70 │ │ arg_type = _torchscript_type_to_python_type(arg.type) │
│ 71 │ │ default = arg.default_value if arg.has_default_value() else Parameter.empty │
│ 72 │ │ # TODO: Figure out if this is safe. It seems like when generating the type signa │
│ 73 │ │ # PythonArgParser, we emit signatures with `input` instead of `self` as the firs │
│ │
│ /opt/pytorch/torch/fx/operator_schemas.py:64 in _torchscript_type_to_python_type │
│ │
│ 61 │ eval'ing the annotation_str. _type_eval_globals sets up expressions │
│ 62 │ like "List" and "Future" to map to actual types (typing.List and jit.Future) │
│ 63 │ """ │
│ ❱ 64 │ return eval(ts_type.annotation_str, _type_eval_globals) │
│ 65 │
│ 66 def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Sig │
│ 67 │ from inspect import Parameter │
│ <string>:1 in <module> │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
NameError: name 'Storage' is not defined
During handling of the above exception, another exception occurred:
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:467 in │
│ load_state_dict │
│ │
│ 464 │ except Exception as e: │
│ 465 │ │ try: │
│ 466 │ │ │ with open(checkpoint_file) as f: │
│ ❱ 467 │ │ │ │ if f.read(7) == "version": │
│ 468 │ │ │ │ │ raise OSError( │
│ 469 │ │ │ │ │ │ "You seem to have cloned a repository without having git-lfs ins │
│ 470 │ │ │ │ │ │ "git-lfs and run `git lfs install` followed by `git lfs pull` in │
│ │
│ /opt/conda/envs/ptca/lib/python3.8/codecs.py:322 in decode │
│ │
│ 319 │ def decode(self, input, final=False): │
│ 320 │ │ # decode input (taking the buffer into account) │
│ 321 │ │ data = self.buffer + input │
│ ❱ 322 │ │ (result, consumed) = self._buffer_decode(data, self.errors, final) │
│ 323 │ │ # keep undecoded input until the next call │
│ 324 │ │ self.buffer = data[consumed:] │
│ 325 │ │ return result │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte
During handling of the above exception, another exception occurred:
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/pytorch/bug_repro.py:16 in <module> │
│ │
│ 13 fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2") │
│ 14 assert fake_model is not None │
│ 15 with fake_mode: │
│ ❱ 16 │ fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2") # raises │
│ │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py:484 in │
│ from_pretrained │
│ │
│ 481 │ │ │ ) │
│ 482 │ │ elif type(config) in cls._model_mapping.keys(): │
│ 483 │ │ │ model_class = _get_model_class(config, cls._model_mapping) │
│ ❱ 484 │ │ │ return model_class.from_pretrained( │
│ 485 │ │ │ │ pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, │
│ 486 │ │ │ ) │
│ 487 │ │ raise ValueError( │
│ │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:2604 in │
│ from_pretrained │
│ │
│ 2601 │ │ if from_pt: │
│ 2602 │ │ │ if not is_sharded and state_dict is None: │
│ 2603 │ │ │ │ # Time to load the checkpoint │
│ ❱ 2604 │ │ │ │ state_dict = load_state_dict(resolved_archive_file) │
│ 2605 │ │ │ │
│ 2606 │ │ │ # set dtype to instantiate the model under: │
│ 2607 │ │ │ # 1. If torch_dtype is not None, we use that dtype │
│ │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:479 in │
│ load_state_dict │
│ │
│ 476 │ │ │ │ │ │ "model. Make sure you have saved the model properly." │
│ 477 │ │ │ │ │ ) from e │
│ 478 │ │ except (UnicodeDecodeError, ValueError): │
│ ❱ 479 │ │ │ raise OSError( │
│ 480 │ │ │ │ f"Unable to load weights from pytorch checkpoint file for '{checkpoint_f │
│ 481 │ │ │ │ f"at '{checkpoint_file}'. " │
│ 482 │ │ │ │ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please s │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OSError: Unable to load weights from pytorch checkpoint file for '/root/.cache/huggingface/hub/models--sshleifer--tiny-gpt2/snapshots/5f91d94bd9cd7190a9f3216ff93cd1dd95f2c7be/pytorch_model.bin' at
'/root/.cache/huggingface/hub/models--sshleifer--tiny-gpt2/snapshots/5f91d94bd9cd7190a9f3216ff93cd1dd95f2c7be/pytorch_model.bin'. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set
from_tf=True.
```
Repro Scenario 2:
```python
import tempfile
import torch
from torch._subclasses import fake_tensor
class TheModelClass(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.fc1(x)

with tempfile.NamedTemporaryFile() as state_dict_file:
    # Create state_dict to be loaded later
    model = TheModelClass()
    torch.save(model.state_dict(), state_dict_file.name)

    fake_mode = fake_tensor.FakeTensorMode()
    with fake_mode:
        # This is where the bug is triggered
        state_dict = torch.load(state_dict_file.name)
```
Error:
```bash
Traceback (most recent call last):
File "issue_gh_torch_105077.py", line 22, in <module>
state_dict = torch.load(state_dict_file.name)
File "/opt/pytorch/torch/serialization.py", line 1014, in load
return _load(opened_zipfile,
File "/opt/pytorch/torch/serialization.py", line 1422, in _load
result = unpickler.load()
File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
tensor = _rebuild_tensor(storage, storage_offset, size, stride)
File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
return t.set_(storage._untyped_storage, storage_offset, size, stride)
File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
_, new_kwargs = normalize_function(
File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
torch_op_schemas = get_signature_for_torch_op(target)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
arg_type = _torchscript_type_to_python_type(arg.type)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
return eval(ts_type.annotation_str, _type_eval_globals)
File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```
This PR adds the ability to create fake tensors during `torch.load` (when a fake mode is active) by changing the storage's device to `'meta'`, so that rebuilding a tensor never touches real storage.
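A minimal sketch of the idea (not the actual diff; the real change plumbs the active mode through the `torch.load` internals, and `rebuild_tensor_sketch` and its `fake_mode` argument are hypothetical names used only for illustration):
```python
import torch

# Sketch: under an active FakeTensorMode, skip Tensor.set_() on real storage
# (the call that raised above) and materialize the tensor on 'meta', which
# carries shape/stride/dtype but no data. storage_offset handling is elided
# for brevity.
def rebuild_tensor_sketch(storage, storage_offset, size, stride, fake_mode=None):
    if fake_mode is not None:
        with fake_mode:
            # Factory calls under the mode yield FakeTensors backed by meta storage.
            return torch.empty_strided(size, stride, dtype=storage.dtype, device="meta")
    # Original behavior: attach the real untyped storage to an empty tensor.
    t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_storage.device)
    return t.set_(storage._untyped_storage, storage_offset, size, stride)
```
With the fix in place, the Scenario 2 repro above runs to completion and the loaded `state_dict` contains fake tensors, so no real parameter memory is allocated.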
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang, https://github.com/atalman