pytorch
3bbb36e0 - Update linspace types (#32218)

Commit
4 years ago
Update linspace types (#32218) Summary: Changes the linspace functions to be more consistent as requested in https://github.com/pytorch/pytorch/issues/31991. The code has also been updated to avoid an early rounding error; the line `scalar_t step = (scalar_end - scalar_start) / static_cast<static_t>(steps-1)` can result in `step = 0` for integer scalars, and this gives unintended results. I examined the new output using ``` import torch types = [torch.uint8, torch.int8, torch.short, torch.int, torch.long, torch.half, torch.float, torch.double] print('Testing linspace:') for type in types: print(type, torch.linspace(-2, 2, 10, dtype=type)) ``` which returns ``` Testing linspace: torch.uint8 tensor([254, 254, 254, 255, 255, 0, 0, 1, 1, 2], dtype=torch.uint8) torch.int8 tensor([-2, -2, -2, -1, -1, 0, 0, 1, 1, 2], dtype=torch.int8) torch.int16 tensor([-2, -2, -2, -1, -1, 0, 0, 1, 1, 2], dtype=torch.int16) torch.int32 tensor([-2, -2, -2, -1, -1, 0, 0, 1, 1, 2], dtype=torch.int32) torch.int64 tensor([-2, -2, -2, -1, -1, 0, 0, 1, 1, 2]) torch.float16 tensor([-2.0000, -1.5557, -1.1113, -0.6670, -0.2227, 0.2227, 0.6660, 1.1113, 1.5547, 2.0000], dtype=torch.float16) torch.float32 tensor([-2.0000, -1.5556, -1.1111, -0.6667, -0.2222, 0.2222, 0.6667, 1.1111, 1.5556, 2.0000]) torch.float64 tensor([-2.0000, -1.5556, -1.1111, -0.6667, -0.2222, 0.2222, 0.6667, 1.1111, 1.5556, 2.0000], dtype=torch.float64) ``` which is the expected output: `uint8` overflows as it should, and the result of casting from a floating point to an integer is correct. This PR does not change the logspace function. Pull Request resolved: https://github.com/pytorch/pytorch/pull/32218 Differential Revision: D19544224 Pulled By: ngimel fbshipit-source-id: 2bbf2b8552900eaef2dcc41b6464fc39bec22e0b
Author
Parents
Loading