pytorch
cb817d61 - Fix endian handling in THPStorage_fromBuffer (#92834)

Commit
1 year ago
Fix endian handling in THPStorage_fromBuffer (#92834) Fixes #92831 This PR fixes a test failure of `TestTorch.test_from_buffer` on a big-endian machine. The root cause of this failure is that current `THPStorage_fromBuffer` does not perform endian handling correctly on a big-endian. In `THPStorage_fromBuffer`, the given buffer is stored as machine native-endian. Thus, if the specified byte order (e.g. `big`) is equal to machine native-endian, swapping elements should not be performed. However, in the current implementation, [`decode*BE()`](https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/byte_order.cpp#L72-L109) always swaps elements regardless of machine native-endian (i.e. these methods assume buffer is stored as little-endian). Thus, this PR uses the following approaches: - if the specified byte order (e.g. `big`) is equal to machine native-endian, call `decode*LE()` that does not swap elements by passing `torch::utils::THP_LITTLE_ENDIAN` to `THP_decode*Buffer()`. - if the specified byte order (e.g. `big`) is not equal to machine native-endian, call `decode*BE()` that always swap elements by passing `torch::utils::THP_BIG_ENDIAN` to `THP_decode*Buffer()`. After applying this PR to the master branch, I confirmed that the test passes on a big-endian machine. ``` % python test/test_torch.py TestTorch.test_from_buffer /home/ishizaki/PyTorch/master/test/test_torch.py:6367: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage() self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4]) ... /home/ishizaki/PyTorch/master/test/test_torch.py:6396: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage() self.assertEqual(bytes.tolist(), [1, 2, 3, 4]) . ---------------------------------------------------------------------- Ran 1 test in 0.021s OK ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/92834 Approved by: https://github.com/ezyang
Author
Committer
Parents
Loading