[CUDA Pinned Memory] Event recording with non-blocking copies should track the storage context, not the tensor data pointer (#68749)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68749
The logic for asynchronous copies (either HtoD or DtoH) using cudaMemcpyAsync relies on recording an event with the caching host allocator to notify it that a given allocation has been used on a stream - and thus it should wait for that stream to proceed before reusing the host memory.
This tracking is based on the allocator maintaining a map from storage allocation pointers to some state.
If we try to record an event for a pointer we don't understand, we will silently drop the event and ignore it (https://github.com/pytorch/pytorch/blob/9554ebe44e6e73dc75105d4935d41e626e03299b/aten/src/ATen/cuda/CachingHostAllocator.cpp#L171-L175).
Thus, if we use the data_ptr of a Tensor instead of the storage allocation, then reasonable code can lead to incorrectness due to missed events.
One way this can occur is simply by slicing a tensor into sub-tensors - which have different values of `data_ptr()` but share the same storage, for example:
```
image_batch = torch.randn(M, B, C, H, W).pin_memory()
for m in range(M):
sub_batch = image_batch[m].cuda(non_blocking=True)
# sub_batch.data_ptr() != image_batch.data_ptr() except for m == 0.
# however, sub_batch.storage().data_ptr() == image_batch.storage().data_ptr() always.
```
Therefore, we instead use the storage context pointer when recording events, as this is the same state that is tracked by the caching allocator itself. This is a correctness fix, although it's hard to determine how widespread this issue is.
Using the storage context also allows us to use a more efficient structure internally to the caching allocator, which will be sent in future diffs.
Test Plan: Test added which demonstrates the issue, although it's hard to demonstrate the race explicitly.
Reviewed By: ngimel
Differential Revision: D32588785
fbshipit-source-id: d87cc5e49ff8cbf59052c3c97da5b48dd1fe75cc