Fix overlap-comm buffer lifetimes (#7965)
This PR fixes a ZeRO 1/2 overlap-comm correctness issue.
When comparing loss values, we found that only ZeRO2 shows nan as a
loss.
- zero1: 11.201002 -> 11.165665 -> 11.213738 -> 11.121310
- zero2: 11.201002 -> 11.165665 -> nan
- zero3: 11.201002 -> 11.165665 -> 11.204460 -> 11.121443
Here is what we found:
In `allreduce_and_copy_with_multiple_ranks()` and
`allreduce_and_copy()`, the reduction result and copied destination
buffers were used on the reduction stream without recording that stream
on the underlying storage, allowing the caching allocator to recycle
that storage before the queued comm/copy work had completed.
This could impact also ZeRO1 though we only encountered the issue with
ZeRO2.
This PR adds `record_stream` to ensure the buffer is not freed until the
queued work is done.
---------
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>