pytorch
20665235 - Fix `ShardedTensorMetadata.tensor_properties` for Python 3.11 (#91795)

Commit
2 years ago
Fix `ShardedTensorMetadata.tensor_properties` for Python 3.11 (#91795) The `tensor_properties` field of the `ShardedTensorMetadata` dataclass is a reference to a `TensorProperties` object. However, the field is set to `field(default=TensorProperties())` instead of `field(default_factory=TensorProperties)`. This causes an error when using Python 3.11 or later: ```python ValueError: mutable default <class 'torch.distributed._shard.sharded_tensor.metadata.TensorProperties'> for field tensor_properties is not allowed: use default_factory ``` This change in dataclass behavior was introduced in [bpo-44674: Use unhashability as a proxy for mutability for default dataclass __init__ arguments](https://github.com/python/cpython/pull/29867). The current use of `default` instead of `default_factory` also means that all `ShardedTensorMetadata` objects created without specifying `tensor_properties` will share the same `TensorProperties` object. Pull Request resolved: https://github.com/pytorch/pytorch/pull/91795 Approved by: https://github.com/fduwjj
Author
Committer
Parents
Loading