Make wrapper tensor subclasses work in serialization (#2440)
* Make wrapper tensor subclasses work in huggingface_hub serialization (non-safetensors)
Summary:
huggingface_hub serialization relies on the storage_ptr of a tensor to implement its sharding logic, but
a wrapper tensor subclass does not have storage of its own. So we unflatten the tensor and compute its storage_id by
summing the storage_ids of the internal plain tensors. This is a bit hacky; open to more robust ideas.
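The idea above can be sketched as follows. This is a simplified illustration, not the PR's actual code: `get_storage_id` is a hypothetical helper name, and the version-guarded import mirrors the torch version checks mentioned in the commits below.

```python
import torch

try:
    # Available in recent torch versions; the PR guards this with version checks.
    from torch.utils._python_dispatch import is_traceable_wrapper_subclass
except ImportError:  # older torch: no wrapper tensor subclass support
    def is_traceable_wrapper_subclass(t):
        return False


def get_storage_id(tensor: torch.Tensor) -> int:
    """Hypothetical helper: return a storage identifier usable for sharding.

    A wrapper tensor subclass has no storage of its own, so we flatten it
    and combine the storage pointers of its internal plain tensors
    (recursing through nested wrappers).
    """
    if is_traceable_wrapper_subclass(tensor):
        attrs, _metadata = tensor.__tensor_flatten__()
        return sum(get_storage_id(getattr(tensor, attr)) for attr in attrs)
    # Plain tensor: use its storage pointer directly.
    return tensor.untyped_storage().data_ptr()
```

For plain tensors this reduces to the usual storage pointer, so views sharing storage still map to the same id; for wrapper subclasses the summed ids stand in for the missing storage pointer.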
Test Plan:
tested with script in https://github.com/huggingface/transformers/issues/32364
* add tests
* update signature to include new changes for tensor subclass
* add torch version checks and move around import
* more fixes
* tested with torch 2.0.0 and 2.5.0
* remove torch_version_at_least from _torch.py
* simplify code for checking if tensor subclass is available or not
* minor fix
* addressing comments and run tests with torch 2.4.0
* some linting
* add test_split_torch_state_dict_into_shards for tensor subclass state dict
* lint
* style
* quality
---------
Co-authored-by: Lucain <lucain@huggingface.co>
Co-authored-by: Lucain Pouget <lucainp@gmail.com>