Use tree-based sum for floats to avoid numerical instability (#39516)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/38716, fixes https://github.com/pytorch/pytorch/issues/37234
This algorithm does the summation along a single axis with multiple "levels" of accumulator, each of which is designed to hold the sum of an order of magnitude more values than the previous.
e.g. if there are 2^16 elements, the first level will hold the sum of 2^4 elements, and so on in increasing powers of 2: 2^4, 2^8, 2^12 and finally 2^16.
This limits the differences in magnitude of the partial results being added together, and so we don't lose accuracy as the axis length increases.
WIP to write a vectorized version.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39516
Reviewed By: ezyang
Differential Revision: D22106251
Pulled By: ngimel
fbshipit-source-id: b56de4773292439dbda62b91f44ff37715850ae9