pytorch
5f9939f6 - Introduce discontinuity to nested tensor (#80981)

Commit
2 years ago
Introduce discontinuity to nested tensor (#80981) Nested tensor used to assume the buffer memory to be contiguous. However, some operations can break that assumption: * reshape * transpose * slice To be able to access underlying tensors from discontinuous buffer, we need 3 metadata: * sizes of each tensor (`nested_size_tensor_`) * strides of each tensor (`nested_stride_tensor_`) * offset of each tensor (`offsets_`) so we access each tensor by `buffer.as_strided(size, stride, offset)` This pull request introduces the offsets metadata, then added reshape and transpose so that we can create discontinuous cases for testing. Unbind, select, dropout, softmax, bmm are refactored to provide tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/80981 Approved by: https://github.com/jbschlosser
Author
Committer
Parents
Loading