Non-reentrant checkpointing hook fix (#5781)
This PR adds an extra condition so that backward-pass hooks are attached to leaf nodes only when synchronization (`SYNCHRONIZE`) or profiling (`PROFILE_TIME`) is enabled, as otherwise these hooks are unnecessary. Hook code below:
```
def after_backward_hook(_nonuse_grads):
    """the hook registered to all leaf tensors"""
    nonlocal leaf_tensors, backward_visited_leaf_nodes
    backward_visited_leaf_nodes += 1
    # Fire only once, after the last leaf tensor has been visited.
    if backward_visited_leaf_nodes == len(leaf_tensors):
        see_memory_usage("After backward checkpointing code after backward", force=False)
        if PROFILE_TIME:
            timers('backward').stop()
            timers.log(['backward'])
        if SYNCHRONIZE:
            get_accelerator().synchronize()
```
`see_memory_usage` is effectively a no-op here, as `force` is hardcoded to `False`. Thus this hook does real work only when `PROFILE_TIME` or `SYNCHRONIZE` is `True`; otherwise it just adds unnecessary function calls to the backward pass.
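A minimal sketch of the guarded registration this PR describes, assuming the hook is attached in a loop over `leaf_tensors` (the loop structure is illustrative; the flags and hook names come from the snippet above):
```
# Illustrative sketch, not the verbatim patch: register the backward hook
# only when it would do real work. PROFILE_TIME, SYNCHRONIZE, leaf_tensors,
# and after_backward_hook are taken from the snippet above; the loop itself
# is an assumption.
if PROFILE_TIME or SYNCHRONIZE:
    for leaf_tensor in leaf_tensors:
        leaf_tensor.register_hook(after_backward_hook)
```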
Co-authored-by: Heyang Qin <heyangqin@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>