pytorch
5a5258cb - Support the strided tensor on input for torch.cat (#46859)

Commit
4 years ago
Support the strided tensor on input for torch.cat (#46859) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46859 Current implementation, for non-contiguous, it will go to slow path. This change tries to enable fast path for non-contiguous input(up to 4-dim). Test Plan: #benchamark before ``` # ---------------------------------------- # PyTorch/Caffe2 Operator Micro-benchmarks # ---------------------------------------- # Tag : all # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(1,1,1)_N2_dim0_cuda # Input: sizes: (1, 1, 1), N: 2, dim: 0, device: cuda Forward Execution Time (us) : 17.126 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(512,512,2)_N2_dim1_cuda # Input: sizes: (512, 512, 2), N: 2, dim: 1, device: cuda Forward Execution Time (us) : 20.652 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(128,1024,2)_N2_dim1_cuda # Input: sizes: (128, 1024, 2), N: 2, dim: 1, device: cuda Forward Execution Time (us) : 20.412 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(1024,1024,2)_N2_dim0_cuda # Input: sizes: (1024, 1024, 2), N: 2, dim: 0, device: cuda Forward Execution Time (us) : 48.265 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(1025,1023,2)_N2_dim1_cuda # Input: sizes: (1025, 1023, 2), N: 2, dim: 1, device: cuda Forward Execution Time (us) : 52.964 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(1024,1024,2)_N2_dim2_cuda # Input: sizes: (1024, 1024, 2), N: 2, dim: 2, device: cuda Forward Execution Time (us) : 71.111 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[<function<lambda>at0x7f8a3cdc2440>,111,65]_N5_dim0_cuda # Input: sizes: [<function <lambda> at 0x7f8a3cdc2440>, 111, 65], N: 5, dim: 0, device: cuda Forward Execution Time (us) : 39.492 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[96,<function<lambda>at0x7f8a3cdc2b90>,64]_N5_dim1_cuda # Input: sizes: [96, <function <lambda> at 0x7f8a3cdc2b90>, 64], N: 5, dim: 1, device: cuda Forward Execution Time (us) : 31.596 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[128,64,<function<lambda>at0x7f880e7db3b0>]_N5_dim2_cuda # Input: sizes: [128, 64, <function <lambda> at 0x7f880e7db3b0>], N: 5, dim: 2, device: cuda Forward Execution Time (us) : 66.668 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[<function<lambda>at0x7f880e7db5f0>,32,64]_N50_dim0_cuda # Input: sizes: [<function <lambda> at 0x7f880e7db5f0>, 32, 64], N: 50, dim: 0, device: cuda Forward Execution Time (us) : 54.562 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[32,<function<lambda>at0x7f880e7db680>,64]_N50_dim1_cuda # Input: sizes: [32, <function <lambda> at 0x7f880e7db680>, 64], N: 50, dim: 1, device: cuda Forward Execution Time (us) : 53.255 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[33,65,<function<lambda>at0x7f880e7db710>]_N50_dim2_cuda # Input: sizes: [33, 65, <function <lambda> at 0x7f880e7db710>], N: 50, dim: 2, device: cuda Forward Execution Time (us) : 69.771 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(64,32,4,16,32)_N2_dim2_cuda # Input: sizes: (64, 32, 4, 16, 32), N: 2, dim: 2, device: cuda Forward Execution Time (us) : 98.438 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(16,32,4,16,32)_N8_dim2_cuda # Input: sizes: (16, 32, 4, 16, 32), N: 8, dim: 2, device: cuda Forward Execution Time (us) : 115.045 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(9,31,5,15,33)_N17_dim4_cuda # Input: sizes: (9, 31, 5, 15, 33), N: 17, dim: 4, device: cuda Forward Execution Time (us) : 476.497 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[<function<lambda>at0x7f880e7db7a0>]_N100_dim0_cuda # Input: sizes: [<function <lambda> at 0x7f880e7db7a0>], N: 100, dim: 0, device: cuda Forward Execution Time (us) : 86.307 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[<function<lambda>at0x7f880e7db830>]_N1000_dim0_cuda # Input: sizes: [<function <lambda> at 0x7f880e7db830>], N: 1000, dim: 0, device: cuda Forward Execution Time (us) : 453.269 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[<function<lambda>at0x7f880e7db8c0>]_N2000_dim0_cuda # Input: sizes: [<function <lambda> at 0x7f880e7db8c0>], N: 2000, dim: 0, device: cuda Forward Execution Time (us) : 935.365 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[<function<lambda>at0x7f880e7db950>]_N3000_dim0_cuda # Input: sizes: [<function <lambda> at 0x7f880e7db950>], N: 3000, dim: 0, device: cuda Forward Execution Time (us) : 1355.937 ``` after ``` WARNING:2020-11-01 21:14:23 3332963:3336757 EventProfilerController.cpp:143] (x1) Lost sample due to delays (ms): 488, 11, 4121, 0 # ---------------------------------------- # PyTorch/Caffe2 Operator Micro-benchmarks # ---------------------------------------- # Tag : all # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(1,1,1)_N2_dim0_cuda # Input: sizes: (1, 1, 1), N: 2, dim: 0, device: cuda Forward Execution Time (us) : 17.174 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(512,512,2)_N2_dim1_cuda # Input: sizes: (512, 512, 2), N: 2, dim: 1, device: cuda Forward Execution Time (us) : 20.399 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(128,1024,2)_N2_dim1_cuda # Input: sizes: (128, 1024, 2), N: 2, dim: 1, device: cuda Forward Execution Time (us) : 23.349 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(1024,1024,2)_N2_dim0_cuda # Input: sizes: (1024, 1024, 2), N: 2, dim: 0, device: cuda Forward Execution Time (us) : 47.847 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(1025,1023,2)_N2_dim1_cuda # Input: sizes: (1025, 1023, 2), N: 2, dim: 1, device: cuda Forward Execution Time (us) : 53.463 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(1024,1024,2)_N2_dim2_cuda # Input: sizes: (1024, 1024, 2), N: 2, dim: 2, device: cuda Forward Execution Time (us) : 72.789 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[<function<lambda>at0x7fd5b5567710>,111,65]_N5_dim0_cuda # Input: sizes: [<function <lambda> at 0x7fd5b5567710>, 111, 65], N: 5, dim: 0, device: cuda Forward Execution Time (us) : 39.747 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[96,<function<lambda>at0x7fd5b56b1320>,64]_N5_dim1_cuda # Input: sizes: [96, <function <lambda> at 0x7fd5b56b1320>, 64], N: 5, dim: 1, device: cuda Forward Execution Time (us) : 31.814 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[128,64,<function<lambda>at0x7fd3a2289680>]_N5_dim2_cuda # Input: sizes: [128, 64, <function <lambda> at 0x7fd3a2289680>], N: 5, dim: 2, device: cuda Forward Execution Time (us) : 67.202 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[<function<lambda>at0x7fd3a2289710>,32,64]_N50_dim0_cuda # Input: sizes: [<function <lambda> at 0x7fd3a2289710>, 32, 64], N: 50, dim: 0, device: cuda Forward Execution Time (us) : 65.229 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[32,<function<lambda>at0x7fd3a22897a0>,64]_N50_dim1_cuda # Input: sizes: [32, <function <lambda> at 0x7fd3a22897a0>, 64], N: 50, dim: 1, device: cuda Forward Execution Time (us) : 60.843 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[33,65,<function<lambda>at0x7fd3a2289830>]_N50_dim2_cuda # Input: sizes: [33, 65, <function <lambda> at 0x7fd3a2289830>], N: 50, dim: 2, device: cuda Forward Execution Time (us) : 69.756 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(64,32,4,16,32)_N2_dim2_cuda # Input: sizes: (64, 32, 4, 16, 32), N: 2, dim: 2, device: cuda Forward Execution Time (us) : 98.222 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(16,32,4,16,32)_N8_dim2_cuda # Input: sizes: (16, 32, 4, 16, 32), N: 8, dim: 2, device: cuda Forward Execution Time (us) : 112.521 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes(9,31,5,15,33)_N17_dim4_cuda # Input: sizes: (9, 31, 5, 15, 33), N: 17, dim: 4, device: cuda Forward Execution Time (us) : 477.736 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[<function<lambda>at0x7fd3a22898c0>]_N100_dim0_cuda # Input: sizes: [<function <lambda> at 0x7fd3a22898c0>], N: 100, dim: 0, device: cuda Forward Execution Time (us) : 50.617 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[<function<lambda>at0x7fd3a2289950>]_N1000_dim0_cuda # Input: sizes: [<function <lambda> at 0x7fd3a2289950>], N: 1000, dim: 0, device: cuda Forward Execution Time (us) : 461.631 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[<function<lambda>at0x7fd3a22899e0>]_N2000_dim0_cuda # Input: sizes: [<function <lambda> at 0x7fd3a22899e0>], N: 2000, dim: 0, device: cuda Forward Execution Time (us) : 840.469 # Benchmarking PyTorch: cat # Mode: Eager # Name: cat_sizes[<function<lambda>at0x7fd3a2289a70>]_N3000_dim0_cuda # Input: sizes: [<function <lambda> at 0x7fd3a2289a70>], N: 3000, dim: 0, device: cuda Forward Execution Time (us) : 1317.866 ``` Reviewed By: ngimel Differential Revision: D24527676 fbshipit-source-id: 83d6431e59fa7e1748292b37f5d1fa4ab6242299
Author
Parents
Loading