compute reduction intermediate buffer size in elements (#63885)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/63869
`iter` strides are in bytes, and we are additionally multiplying size computed using those strides by `sizeof(arg_t)`. Computing `output_memory_size` in elements should be enough.
This doesn't fix the still real problem of allocating large intermediate tensor, but it makes this tensor smaller by typically a factor of 4.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63885
Reviewed By: mruberry
Differential Revision: D30526034
Pulled By: ngimel
fbshipit-source-id: 0aca7f887974b7776e380463bbd82d32a5786ee8
Author
Natalia Gimelshein