[dtensor][TP] check funcol calls and improve doc for loss parallel (#121366)
Since CommDebugMode is fixed, we can check that loss parallel is working as expected.
Under loss parallel, the forward computation should invoke 3 all-reduces, and the backward computation should invoke no functional collectives.
Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121366
Approved by: https://github.com/wanchaol