pytorch
5817695b - [pt2] Fix arange to match ATen behavior (#93353)

Commit
1 year ago
[pt2] Fix arange to match ATen behavior (#93353) Fixes #92676 `arange` infers the output dtype from the argument types, but in order to reduce falling back to ATen, inductor preferred to cast whole number float arguments to int which gave the wrong output dtype. Instead, this decomposes floating point arange into the prim equivalent for integers. This also changes the signature of `prims.arange` to ```python prims.iota(length, *, start, step, **factory_kwargs) ``` which only supports integers arguments. This is done because calculating the output size from `start, end, step` is surprisingly complex and liable to off by one errors so should not be duplicated in each backend. Pull Request resolved: https://github.com/pytorch/pytorch/pull/93353 Approved by: https://github.com/ngimel, https://github.com/lezcano
Author
Committer
Parents
Loading