Use high precision accumulate buffer for bf16 accumulation (#84402)
Accumulation is not friendly to BFloat16 because its mantissa has only 7 bits: once an operand is much smaller than the running sum, adding it no longer changes the result.
Take `a += b` as an example: `a` keeps growing as the computation runs, so the magnitude gap between `a` and `b` widens until adding `b` no longer affects `a` at all.
Hence, the best practice is to accumulate in FP32 and convert back to BF16 once the accumulation is finished, and this PR follows that practice.
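For illustration only (not part of this PR's code), a small standalone PyTorch snippet shows the effect: naively accumulating 1.0 in BF16 stalls once the running sum reaches 256, where the spacing between adjacent BF16 values is already 2, while accumulating in an FP32 buffer and casting back once at the end gives the expected result.

```python
import torch

# Accumulate 1.0 a thousand times, once natively in BF16 and once in FP32.
acc_bf16 = torch.tensor(0.0, dtype=torch.bfloat16)
acc_fp32 = torch.tensor(0.0, dtype=torch.float32)
one_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)

for _ in range(1000):
    acc_bf16 += one_bf16          # BF16 accumulation: stalls at 256.0
    acc_fp32 += one_bf16.float()  # FP32 accumulation buffer

print(acc_bf16.item())                     # 256.0 -- the +1.0 updates are lost
print(acc_fp32.to(torch.bfloat16).item())  # 1000.0 -- accumulate in FP32, cast back once
```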
We extend `ReduceOp` with an `accumulation` buffer and record the result buffer and the `Reducer`'s operand, because the original `ReduceOp` has to be replaced with a new `ReduceOp` that reduces into the `accumulation` buffer. Concretely (a conceptual sketch follows the list below):
- Extend `ReduceOp` by adding an `accumulation` buffer and recording the result buffer and the `Reducer`'s operand - [PR change](https://github.com/pytorch/pytorch/pull/84402/files#diff-0f4be13525117d5c49c69bd18e92eb15dda36b5a59b7a10c7e1114f5cac10afbR225-R229)
- Replace the original `ReduceOp` with a new `ReduceOp` that uses the `accumulation` buffer for the reduction - [PR change](https://github.com/pytorch/pytorch/pull/84402/files#diff-fac6725328dc01e235944c7afc9f29c804488973c02c25ecd93d562884d959b3R26-R36)
- Cast the accumulation buffer from FP32 to BF16 and write it back to the result buffer - [PR change](https://github.com/pytorch/pytorch/pull/84402/files#diff-fac6725328dc01e235944c7afc9f29c804488973c02c25ecd93d562884d959b3R62-R67)
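The real change is in the C++ code linked above; the following is only a rough Python sketch of the pattern the three steps implement (the function and buffer names are made up for illustration and do not correspond to the PR's actual `ReduceOp`/`Reducer` code):

```python
import torch

def bf16_reduce_sum(operand: torch.Tensor, result: torch.Tensor) -> None:
    """Hypothetical sketch of the rewritten reduction, not the PR's actual code.

    operand: BF16 input of the reducer.
    result:  BF16 result buffer the original reduction wrote to.
    """
    # Step 1: allocate a high-precision (FP32) accumulation buffer.
    accumulation = torch.zeros_like(result, dtype=torch.float32)

    # Step 2: run the reduction into the accumulation buffer instead of
    # writing directly into the BF16 result buffer.
    for row in operand.unbind(0):
        accumulation += row.to(torch.float32)

    # Step 3: cast the accumulation buffer back to BF16 and write it into
    # the original result buffer.
    result.copy_(accumulation.to(torch.bfloat16))

# Usage: reduce a (1000, 8) BF16 tensor of ones over dim 0.
operand = torch.ones(1000, 8, dtype=torch.bfloat16)
result = torch.empty(8, dtype=torch.bfloat16)
bf16_reduce_sum(operand, result)
print(result)  # tensor of 1000.0 in BF16, instead of stalling at 256.0
```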
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84402
Approved by: https://github.com/frank-wei