Better Handling of Storage Cache (#98254)
Because we do not persist output memory of cudagraphs, we need to reconstruct tensors at their correct memory locations after we've done a run. We were using a storage cache for that but it had a couple of issues:
- If a data ptr existed in the cache, we should only reuse the corresponding storage if that storage had not died.
- It did not work across separate nodes. While you might not expect this to be an issue, it came up in testing HF.
- StorageWeakRef maintains whether the Storage C++ object remains allocated, not whether the corresponding memory has been deallocated. In order to use them to track memory deallocations we must maintain a single StorageWeakRef for all Storages that reference that memory (even if we are constructing Storages that do not have a deallocator function).
This PR maintains a single storage_cache as we execute any tree path. When we retrieve a storage from the cache we
check that it is still alive, and we key the cache on both the observed recording data ptr and the StorageImpl weak ref.
Update to use a single storage cache across all executions in a path.
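The lookup pattern described above can be sketched in plain Python. This is a minimal illustration, not the PyTorch implementation: the `Storage` and `StorageCache` names are hypothetical stand-ins, and the stdlib `weakref` module plays the role of `StorageWeakRef`. The key point is that a cache hit is only valid if the weak reference is still alive; a dead reference means the memory may have been reused and the stale entry must be discarded.

```python
import weakref


class Storage:
    """Hypothetical stand-in for a storage object owning a memory region."""

    def __init__(self, data_ptr):
        self.data_ptr = data_ptr


class StorageCache:
    """Maps an observed data ptr to a weak reference to its storage.

    On lookup, a dead weak ref means the underlying storage was
    deallocated (and the address may have been reused), so the cached
    entry must not be returned.
    """

    def __init__(self):
        self._cache = {}

    def add(self, storage):
        self._cache[storage.data_ptr] = weakref.ref(storage)

    def get(self, data_ptr):
        ref = self._cache.get(data_ptr)
        if ref is None:
            return None
        storage = ref()  # None if the storage has died
        if storage is None:
            del self._cache[data_ptr]  # drop the stale entry
            return None
        return storage


cache = StorageCache()
s = Storage(0x1000)
cache.add(s)
assert cache.get(0x1000) is s  # live storage: safe to reuse
del s
assert cache.get(0x1000) is None  # storage died: entry invalidated
```

A single such cache shared across all executions of a path (rather than one per node) is what lets liveness be tracked consistently even when multiple Storages reference the same memory.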
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98254
Approved by: https://github.com/jansel