[LTC] Add custom lazy tensor save function (#83294)
We need a custom `save` function for checkpointing a lazy model, similar to what exists in PyTorch/XLA:
https://github.com/pytorch/xla/blob/3eb8a9d9eb4ebb0b064461c3704650241625654e/torch_xla/core/xla_model.py#L994
The purpose of this function is to move any lazy tensors to CPU before saving the checkpoint.
The way I implemented it was to create a general structure visitor, adapted from a function that we use quite often in Cerebras internal repositories. If there is a better tool already available in PyTorch that does the same things, I'm open to suggestions.
CC: @wconstab @Krovatkin @JackCaoG
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83294
Approved by: https://github.com/wconstab