pytorch
3498fde2 - Add AccumulateType in AdaptiveAveragePooling3d.cu (#53607)

Commit
3 years ago
Add AccumulateType in AdaptiveAveragePooling3d.cu (#53607) Summary: Fixes https://github.com/pytorch/pytorch/issues/52719 - Changed the type(`scalar_t`) of intermediate results to `at::acc_type<scalar_t, true>` This issue occurs by decimal precision of the half precision. Follows test cases of upper issue, The value range of input tensors are [0, 1] because init by `rand`. And when the kernel size 1, summations all target values and divide numel of kernel https://github.com/pytorch/pytorch/blob/34d9278c1913d83db47a25f0cac71b69ab877c84/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu#L94-L95 When adding [0, 1] values, if `sum` more than 2048 then not changed values. ( Even if the value is small, the mored exact value is added, but there are still precision issues.) (https://en.wikipedia.org/wiki/Half-precision_floating-point_format) Benchmarks - In V100 32GB, Driver : 450.80, cuda 10.1 - faster than prev <details><summary>Script</summary><p> ```import torch from torch.utils.benchmark import Timer torch.manual_seed(0) kernel_sizes = [1, 3, 5, 7, 9, 11, 13] shapes = [(12, 12, 12), (16, 16, 16), (16, 32, 32), (16, 56, 56), (16, 112, 112)] def run(batch, channel): print(f"Batch : {batch}, Channel : {channel} / (diff, diff / numel, time)") head = "\t".join(f"{str(s):30s}" for s in ["k \ shape"] + shapes) print(head) for kernel_size in kernel_sizes: kernel_size = (kernel_size, kernel_size, kernel_size) pool = torch.nn.AdaptiveAvgPool3d(kernel_size) print(f"{str(kernel_size):30s}", end="\t") for shape in shapes: x_half = torch.rand([batch, channel, *shape], dtype=torch.half, device="cuda") x_float = x_half.float() y_half = pool(x_half) y_float = pool(x_float) timer = Timer("pool(x_half)", globals={"pool": pool, "x_half": x_half}) measurement = timer.blocked_autorange(min_run_time=5) diff = (y_float - y_half).abs().sum().item() diff = f"{diff:.4f}, {diff / y_half.numel():.6f}, {measurement.median * 1e6 :3.2f}us" print(f"{diff:30s}", end="\t") print("") run(1, 1) run(1, 3) run(1, 54) run(1, 16) run(8, 1) run(8, 16) run(8, 54) import torch m = torch.nn.AdaptiveAvgPool3d((1,1,1)) inputs = torch.rand([8,54,16,56,56]) inputs = inputs.cuda() inputs_2 = inputs.half() print("Float") out = m(inputs).float() print("half") out2 = m(inputs_2).float() print('Discepancies', torch.sum(torch.abs(out2- out)).item(), torch.sum(torch.abs(out2- out)).item() / out.numel() , out.numel()) print("Sum : ", torch.sum(inputs, dim=(2,3,4))[0, 0], torch.sum(inputs_2, dim=(2,3,4))[0, 0]) ``` </p> </details> <details><summary>This commit</summary><p> ``` Batch : 1, Channel : 1 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.0001, 0.000078, 55.73us 0.0001, 0.000079, 117.51us 0.0000, 0.000003, 379.60us 0.0000, 0.000046, 1046.21us 0.0001, 0.000139, 3897.17us (3, 3, 3) 0.0021, 0.000076, 22.04us 0.0031, 0.000115, 21.47us 0.0022, 0.000080, 41.63us 0.0030, 0.000111, 100.59us 0.0025, 0.000091, 295.04us (5, 5, 5) 0.0103, 0.000083, 21.65us 0.0097, 0.000078, 21.37us 0.0103, 0.000083, 21.60us 0.0114, 0.000091, 25.69us 0.0107, 0.000085, 97.06us (7, 7, 7) 0.0312, 0.000091, 21.52us 0.0290, 0.000084, 21.61us 0.0311, 0.000091, 21.60us 0.0309, 0.000090, 21.44us 0.0334, 0.000097, 33.60us (9, 9, 9) 0.0646, 0.000089, 21.57us 0.0672, 0.000092, 21.89us 0.0662, 0.000091, 21.89us 0.0684, 0.000094, 27.64us 0.0660, 0.000091, 54.85us (11, 11, 11) 0.1251, 0.000094, 21.68us 0.1194, 0.000090, 21.70us 0.1202, 0.000090, 21.72us 0.1233, 0.000093, 22.25us 0.1229, 0.000092, 41.39us (13, 13, 13) 0.2038, 0.000093, 21.57us 0.2047, 0.000093, 21.58us 0.1964, 0.000089, 21.54us 0.2021, 0.000092, 21.94us 0.1989, 0.000091, 40.01us Batch : 1, Channel : 3 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.0003, 0.000110, 55.74us 0.0003, 0.000093, 118.62us 0.0003, 0.000093, 382.12us 0.0001, 0.000040, 1052.33us 0.0003, 0.000114, 3917.90us (3, 3, 3) 0.0073, 0.000090, 21.84us 0.0075, 0.000093, 22.25us 0.0072, 0.000089, 41.78us 0.0070, 0.000087, 100.27us 0.0069, 0.000086, 293.96us (5, 5, 5) 0.0353, 0.000094, 22.57us 0.0325, 0.000087, 21.64us 0.0343, 0.000092, 22.63us 0.0338, 0.000090, 25.82us 0.0332, 0.000089, 97.16us (7, 7, 7) 0.0937, 0.000091, 22.50us 0.0910, 0.000088, 21.92us 0.0933, 0.000091, 21.99us 0.0948, 0.000092, 21.56us 0.0928, 0.000090, 34.17us (9, 9, 9) 0.1957, 0.000089, 21.68us 0.1984, 0.000091, 21.57us 0.2025, 0.000093, 22.10us 0.1986, 0.000091, 27.66us 0.2020, 0.000092, 55.32us (11, 11, 11) 0.3585, 0.000090, 21.75us 0.3684, 0.000092, 22.70us 0.3706, 0.000093, 21.67us 0.3752, 0.000094, 21.86us 0.3663, 0.000092, 41.22us (13, 13, 13) 0.5931, 0.000090, 21.67us 0.6056, 0.000092, 21.79us 0.6005, 0.000091, 21.79us 0.6112, 0.000093, 21.69us 0.6034, 0.000092, 40.02us Batch : 1, Channel : 54 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.0051, 0.000095, 55.76us 0.0060, 0.000112, 118.60us 0.0036, 0.000067, 381.50us 0.0054, 0.000100, 1054.03us 0.0048, 0.000089, 4888.68us (3, 3, 3) 0.1332, 0.000091, 21.66us 0.1344, 0.000092, 22.62us 0.1354, 0.000093, 45.72us 0.1364, 0.000094, 106.63us 0.1324, 0.000091, 448.31us (5, 5, 5) 0.6221, 0.000092, 22.48us 0.6220, 0.000092, 21.71us 0.6053, 0.000090, 27.65us 0.6137, 0.000091, 31.40us 0.6209, 0.000092, 172.78us (7, 7, 7) 1.6859, 0.000091, 22.42us 1.6972, 0.000092, 21.96us 1.6849, 0.000091, 23.14us 1.7012, 0.000092, 26.25us 1.6920, 0.000091, 75.58us (9, 9, 9) 3.5811, 0.000091, 21.73us 3.5746, 0.000091, 22.55us 3.6237, 0.000092, 27.66us 3.6046, 0.000092, 59.71us 3.6392, 0.000092, 168.15us (11, 11, 11) 6.5582, 0.000091, 22.05us 6.5746, 0.000091, 21.74us 6.5955, 0.000092, 32.91us 6.5644, 0.000091, 45.57us 6.5697, 0.000091, 114.01us (13, 13, 13) 10.6384, 0.000090, 21.81us 10.8608, 0.000092, 21.79us 10.8375, 0.000091, 37.01us 10.8662, 0.000092, 51.80us 10.8593, 0.000092, 123.19us Batch : 1, Channel : 16 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.0015, 0.000093, 55.75us 0.0012, 0.000075, 118.10us 0.0013, 0.000079, 379.25us 0.0012, 0.000075, 1047.21us 0.0013, 0.000079, 4451.57us (3, 3, 3) 0.0407, 0.000094, 21.82us 0.0395, 0.000091, 21.69us 0.0385, 0.000089, 42.07us 0.0397, 0.000092, 100.33us 0.0384, 0.000089, 363.31us (5, 5, 5) 0.1858, 0.000093, 21.76us 0.1799, 0.000090, 21.63us 0.1834, 0.000092, 21.76us 0.1890, 0.000095, 26.04us 0.1814, 0.000091, 135.32us (7, 7, 7) 0.4937, 0.000090, 21.65us 0.5076, 0.000092, 21.69us 0.5001, 0.000091, 22.31us 0.4988, 0.000091, 21.59us 0.5123, 0.000093, 50.03us (9, 9, 9) 1.0678, 0.000092, 21.73us 1.0752, 0.000092, 21.75us 1.0673, 0.000091, 21.75us 1.0649, 0.000091, 30.01us 1.0786, 0.000092, 70.92us (11, 11, 11) 1.9591, 0.000092, 21.57us 1.9522, 0.000092, 21.60us 1.9566, 0.000092, 21.73us 1.9475, 0.000091, 23.46us 1.9323, 0.000091, 55.02us (13, 13, 13) 3.1784, 0.000090, 22.02us 3.2165, 0.000092, 21.95us 3.1969, 0.000091, 21.92us 3.2061, 0.000091, 24.40us 3.2578, 0.000093, 56.00us Batch : 8, Channel : 1 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.0010, 0.000122, 55.74us 0.0009, 0.000114, 118.82us 0.0006, 0.000074, 379.80us 0.0009, 0.000107, 1047.31us 0.0008, 0.000102, 3900.36us (3, 3, 3) 0.0219, 0.000101, 21.57us 0.0200, 0.000093, 21.61us 0.0194, 0.000090, 41.74us 0.0208, 0.000096, 99.91us 0.0212, 0.000098, 293.03us (5, 5, 5) 0.0906, 0.000091, 21.46us 0.0911, 0.000091, 21.60us 0.0934, 0.000093, 21.93us 0.0927, 0.000093, 25.74us 0.0913, 0.000091, 96.85us (7, 7, 7) 0.2530, 0.000092, 22.53us 0.2526, 0.000092, 22.46us 0.2558, 0.000093, 22.03us 0.2542, 0.000093, 22.29us 0.2475, 0.000090, 34.44us (9, 9, 9) 0.5305, 0.000091, 22.34us 0.5368, 0.000092, 22.42us 0.5265, 0.000090, 21.74us 0.5370, 0.000092, 27.81us 0.5416, 0.000093, 55.65us (11, 11, 11) 0.9887, 0.000093, 21.80us 0.9660, 0.000091, 21.61us 0.9793, 0.000092, 22.11us 0.9719, 0.000091, 21.80us 0.9650, 0.000091, 43.90us (13, 13, 13) 1.6024, 0.000091, 21.87us 1.6198, 0.000092, 22.65us 1.6242, 0.000092, 21.73us 1.6236, 0.000092, 22.59us 1.6025, 0.000091, 42.77us Batch : 8, Channel : 16 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.0113, 0.000088, 56.66us 0.0117, 0.000091, 119.57us 0.0130, 0.000102, 389.57us 0.0110, 0.000086, 1433.78us 0.0119, 0.000093, 5217.61us (3, 3, 3) 0.3209, 0.000093, 21.54us 0.3184, 0.000092, 22.87us 0.3115, 0.000090, 51.00us 0.3171, 0.000092, 164.17us 0.3182, 0.000092, 500.60us (5, 5, 5) 1.4391, 0.000090, 22.39us 1.4577, 0.000091, 21.69us 1.4601, 0.000091, 53.87us 1.4626, 0.000091, 93.65us 1.4567, 0.000091, 370.11us (7, 7, 7) 4.0501, 0.000092, 22.34us 4.0230, 0.000092, 31.45us 4.0381, 0.000092, 45.19us 4.0171, 0.000091, 65.35us 4.0108, 0.000091, 164.76us (9, 9, 9) 8.5360, 0.000091, 22.80us 8.5456, 0.000092, 27.24us 8.5461, 0.000092, 50.23us 8.5677, 0.000092, 117.63us 8.5645, 0.000092, 270.46us (11, 11, 11) 15.5521, 0.000091, 26.56us 15.5826, 0.000091, 32.81us 15.6014, 0.000092, 63.82us 15.5620, 0.000091, 96.87us 15.5722, 0.000091, 220.24us (13, 13, 13) 25.4146, 0.000090, 32.91us 25.7898, 0.000092, 38.48us 25.6698, 0.000091, 72.02us 25.8193, 0.000092, 121.73us 25.7718, 0.000092, 249.71us Batch : 8, Channel : 54 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.0377, 0.000087, 109.07us 0.0405, 0.000094, 233.17us 0.0392, 0.000091, 998.97us 0.0393, 0.000091, 2960.68us 0.0408, 0.000094, 11879.53us (3, 3, 3) 1.0660, 0.000091, 25.68us 1.0761, 0.000092, 64.12us 1.0725, 0.000092, 182.50us 1.0801, 0.000093, 505.82us 1.0736, 0.000092, 1650.21us (5, 5, 5) 4.9587, 0.000092, 50.84us 4.9336, 0.000091, 47.38us 4.9696, 0.000092, 158.49us 4.9347, 0.000091, 237.39us 4.9303, 0.000091, 965.13us (7, 7, 7) 13.5409, 0.000091, 45.60us 13.5736, 0.000092, 87.45us 13.5012, 0.000091, 141.63us 13.6111, 0.000092, 181.51us 13.5296, 0.000091, 469.77us (9, 9, 9) 28.7817, 0.000091, 58.01us 28.7969, 0.000091, 77.61us 28.8761, 0.000092, 159.33us 28.8786, 0.000092, 334.47us 28.8093, 0.000091, 786.72us (11, 11, 11) 52.4453, 0.000091, 78.19us 52.7265, 0.000092, 95.12us 52.7322, 0.000092, 200.38us 52.6342, 0.000092, 282.41us 52.6467, 0.000092, 652.54us (13, 13, 13) 85.7411, 0.000090, 98.85us 86.7183, 0.000091, 115.28us 86.8545, 0.000092, 232.34us 86.9997, 0.000092, 367.32us 86.9083, 0.000092, 757.73us Float half Discepancies 0.03963914513587952 9.175728040712852e-05 432 Sum : tensor(25110.1484, device='cuda:0') tensor(25104., device='cuda:0', dtype=torch.float16) ``` </p> </details> <details><summary>1.8.0</summary><p> ``` Batch : 1, Channel : 1 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.0023, 0.002275, 74.35us 0.0040, 0.003985, 159.73us 0.3740, 0.374021, 546.59us 0.4587, 0.458663, 1543.16us 0.4906, 0.490637, 5945.97us (3, 3, 3) 0.0100, 0.000370, 20.37us 0.0230, 0.000852, 22.12us 0.0309, 0.001143, 54.75us 0.0520, 0.001926, 129.78us 7.1219, 0.263775, 377.11us (5, 5, 5) 0.0441, 0.000352, 20.06us 0.0394, 0.000316, 20.50us 0.0759, 0.000607, 26.43us 0.1499, 0.001199, 32.01us 0.2707, 0.002166, 128.15us (7, 7, 7) 0.0791, 0.000231, 20.10us 0.1002, 0.000292, 20.56us 0.1812, 0.000528, 20.48us 0.2424, 0.000707, 20.83us 0.4994, 0.001456, 43.97us (9, 9, 9) 0.1122, 0.000154, 20.55us 0.1778, 0.000244, 20.44us 0.2572, 0.000353, 20.15us 0.4149, 0.000569, 35.64us 0.7208, 0.000989, 68.46us (11, 11, 11) 0.2044, 0.000154, 20.47us 0.2647, 0.000199, 20.62us 0.3867, 0.000291, 20.61us 0.6059, 0.000455, 23.54us 1.0902, 0.000819, 53.32us (13, 13, 13) 0.3094, 0.000141, 20.53us 0.3843, 0.000175, 20.60us 0.5756, 0.000262, 20.80us 0.8598, 0.000391, 24.52us 1.4853, 0.000676, 47.70us Batch : 1, Channel : 3 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.0054, 0.001801, 74.36us 0.0108, 0.003614, 158.94us 1.1183, 0.372768, 547.67us 1.3782, 0.459387, 1545.27us 1.4685, 0.489505, 5949.17us (3, 3, 3) 0.0308, 0.000380, 20.14us 0.0502, 0.000619, 22.11us 0.1210, 0.001493, 54.80us 0.1900, 0.002345, 130.47us 21.3483, 0.263560, 375.68us (5, 5, 5) 0.1179, 0.000314, 20.68us 0.1326, 0.000354, 20.53us 0.2662, 0.000710, 26.51us 0.4116, 0.001098, 31.85us 0.8369, 0.002232, 128.19us (7, 7, 7) 0.2335, 0.000227, 20.40us 0.3057, 0.000297, 20.43us 0.4954, 0.000481, 20.31us 0.7339, 0.000713, 20.74us 1.4208, 0.001381, 44.55us (9, 9, 9) 0.3326, 0.000152, 20.63us 0.5353, 0.000245, 20.42us 0.8025, 0.000367, 20.13us 1.2693, 0.000580, 35.64us 2.2096, 0.001010, 68.88us (11, 11, 11) 0.6121, 0.000153, 20.59us 0.8086, 0.000202, 20.42us 1.1700, 0.000293, 20.71us 1.8170, 0.000455, 23.54us 3.2117, 0.000804, 53.36us (13, 13, 13) 0.9165, 0.000139, 20.51us 1.1395, 0.000173, 20.56us 1.7343, 0.000263, 20.80us 2.5868, 0.000392, 24.59us 4.5823, 0.000695, 47.77us Batch : 1, Channel : 54 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.1092, 0.002023, 75.45us 0.1709, 0.003165, 160.44us 20.2452, 0.374911, 548.61us 24.7990, 0.459240, 1550.34us 26.4494, 0.489804, 6957.79us (3, 3, 3) 0.5352, 0.000367, 20.58us 1.0281, 0.000705, 24.14us 2.0150, 0.001382, 59.12us 3.3069, 0.002268, 138.23us 384.5216, 0.263732, 529.71us (5, 5, 5) 2.0739, 0.000307, 20.60us 2.5199, 0.000373, 20.44us 4.6916, 0.000695, 33.89us 7.9482, 0.001178, 37.74us 14.2553, 0.002112, 200.54us (7, 7, 7) 4.2236, 0.000228, 20.61us 5.5605, 0.000300, 20.97us 9.0440, 0.000488, 26.40us 12.7847, 0.000690, 30.64us 25.3050, 0.001366, 88.05us (9, 9, 9) 6.0817, 0.000154, 20.63us 9.5416, 0.000242, 20.84us 14.2416, 0.000362, 32.47us 22.8452, 0.000580, 78.57us 40.3246, 0.001024, 194.50us (11, 11, 11) 11.1144, 0.000155, 20.56us 14.5581, 0.000203, 20.91us 20.8263, 0.000290, 38.07us 33.0004, 0.000459, 52.74us 57.3275, 0.000798, 137.19us (13, 13, 13) 16.5176, 0.000139, 21.26us 20.8089, 0.000175, 22.33us 31.3433, 0.000264, 42.93us 45.9733, 0.000388, 59.84us 82.8301, 0.000698, 138.42us Batch : 1, Channel : 16 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.0274, 0.001715, 74.99us 0.0485, 0.003034, 159.92us 5.9925, 0.374529, 546.35us 7.3389, 0.458679, 1544.53us 7.8354, 0.489714, 6677.00us (3, 3, 3) 0.1560, 0.000361, 20.72us 0.3043, 0.000704, 22.37us 0.5838, 0.001352, 54.97us 1.0455, 0.002420, 130.57us 113.9739, 0.263828, 463.43us (5, 5, 5) 0.6121, 0.000306, 20.12us 0.7247, 0.000362, 20.73us 1.3740, 0.000687, 26.59us 2.3794, 0.001190, 32.12us 4.1929, 0.002096, 165.81us (7, 7, 7) 1.2389, 0.000226, 20.59us 1.6311, 0.000297, 20.53us 2.6732, 0.000487, 20.37us 3.7501, 0.000683, 20.71us 7.4575, 0.001359, 59.16us (9, 9, 9) 1.7983, 0.000154, 20.64us 2.8075, 0.000241, 20.59us 4.2165, 0.000361, 20.38us 6.7153, 0.000576, 38.29us 12.0530, 0.001033, 86.33us (11, 11, 11) 3.3326, 0.000156, 20.56us 4.3061, 0.000202, 20.67us 6.2235, 0.000292, 20.47us 9.8009, 0.000460, 27.41us 16.9994, 0.000798, 68.49us (13, 13, 13) 4.9016, 0.000139, 20.63us 6.1261, 0.000174, 20.65us 9.2106, 0.000262, 20.93us 13.5843, 0.000386, 27.95us 24.6476, 0.000701, 64.88us Batch : 8, Channel : 1 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.0170, 0.002122, 74.99us 0.0316, 0.003946, 160.66us 3.0013, 0.375158, 546.94us 3.6780, 0.459753, 1544.58us 3.9197, 0.489966, 5948.43us (3, 3, 3) 0.0821, 0.000380, 20.27us 0.1559, 0.000722, 22.29us 0.3133, 0.001450, 54.72us 0.5100, 0.002361, 130.12us 57.0481, 0.264111, 376.71us (5, 5, 5) 0.3075, 0.000307, 20.57us 0.3680, 0.000368, 20.69us 0.6786, 0.000679, 26.61us 1.1744, 0.001174, 31.77us 2.0654, 0.002065, 128.31us (7, 7, 7) 0.6512, 0.000237, 20.60us 0.8359, 0.000305, 20.50us 1.3712, 0.000500, 20.75us 1.9472, 0.000710, 20.92us 3.7586, 0.001370, 44.59us (9, 9, 9) 0.9138, 0.000157, 20.43us 1.4198, 0.000243, 20.58us 2.1018, 0.000360, 20.52us 3.3691, 0.000578, 35.90us 5.9491, 0.001020, 69.16us (11, 11, 11) 1.6606, 0.000156, 20.63us 2.1599, 0.000203, 20.57us 3.1240, 0.000293, 20.98us 4.8874, 0.000459, 24.65us 8.4780, 0.000796, 56.47us (13, 13, 13) 2.4987, 0.000142, 20.71us 3.0667, 0.000174, 20.45us 4.6387, 0.000264, 20.76us 6.8187, 0.000388, 25.95us 12.2077, 0.000695, 50.46us Batch : 8, Channel : 16 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.2635, 0.002059, 75.66us 0.4030, 0.003149, 161.78us 48.0296, 0.375231, 550.46us 58.7787, 0.459209, 1902.41us 62.6966, 0.489817, 7817.48us (3, 3, 3) 1.2271, 0.000355, 20.72us 2.4185, 0.000700, 26.44us 4.6933, 0.001358, 64.66us 7.7016, 0.002228, 192.69us 912.0736, 0.263910, 593.69us (5, 5, 5) 4.8716, 0.000304, 24.75us 5.8624, 0.000366, 21.39us 11.0705, 0.000692, 66.94us 18.9280, 0.001183, 104.93us 34.0512, 0.002128, 441.81us (7, 7, 7) 10.1713, 0.000232, 20.98us 13.2273, 0.000301, 36.26us 21.5426, 0.000491, 52.18us 30.1910, 0.000688, 72.94us 59.8381, 0.001363, 191.52us (9, 9, 9) 14.4542, 0.000155, 23.85us 22.6579, 0.000243, 30.59us 33.8839, 0.000363, 57.40us 54.3563, 0.000583, 142.53us 95.8123, 0.001027, 309.24us (11, 11, 11) 26.3348, 0.000155, 30.07us 34.3043, 0.000201, 37.01us 49.8093, 0.000292, 74.04us 78.3720, 0.000460, 110.53us 136.5404, 0.000801, 264.14us (13, 13, 13) 39.3550, 0.000140, 37.38us 49.3207, 0.000175, 43.51us 74.1139, 0.000264, 83.70us 108.7627, 0.000387, 136.09us 196.5412, 0.000699, 280.16us Batch : 8, Channel : 54 / (diff, diff / numel, time) k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112) (1, 1, 1) 0.8467, 0.001960, 147.36us 1.3993, 0.003239, 314.95us 162.0182, 0.375042, 1327.22us 198.3226, 0.459080, 3921.79us 211.6123, 0.489843, 15646.94us (3, 3, 3) 4.3146, 0.000370, 29.23us 8.1125, 0.000696, 74.94us 15.8886, 0.001362, 223.69us 26.2404, 0.002250, 601.33us 3076.5354, 0.263763, 1974.06us (5, 5, 5) 16.5032, 0.000306, 58.79us 19.6887, 0.000365, 53.79us 37.2731, 0.000690, 192.34us 63.3076, 0.001172, 270.01us 114.8880, 0.002128, 1148.56us (7, 7, 7) 34.0802, 0.000230, 51.12us 44.4087, 0.000300, 100.93us 72.4613, 0.000489, 161.48us 101.9317, 0.000688, 202.91us 201.8955, 0.001363, 545.33us (9, 9, 9) 48.8179, 0.000155, 65.78us 76.3465, 0.000242, 87.48us 114.0228, 0.000362, 179.11us 182.9805, 0.000581, 403.66us 322.7040, 0.001025, 894.86us (11, 11, 11) 88.9993, 0.000155, 88.69us 116.4213, 0.000202, 107.55us 168.3363, 0.000293, 228.71us 264.2232, 0.000460, 322.84us 459.1324, 0.000799, 784.25us (13, 13, 13) 132.7447, 0.000140, 112.91us 165.4525, 0.000174, 131.08us 249.7127, 0.000263, 266.43us 367.0824, 0.000387, 410.17us 663.1367, 0.000699, 847.87us Float half Discepancies 198.37625122070312 0.4592042852331091 432 Sum : tensor(25110.1484, device='cuda:0') tensor(25104., device='cuda:0', dtype=torch.float16) ``` </p> </details> ngimel malfet anjali411 Pull Request resolved: https://github.com/pytorch/pytorch/pull/53607 Reviewed By: mruberry Differential Revision: D27652337 Pulled By: ngimel fbshipit-source-id: 6439c0cafe6ca3f761a3f5d058050a55e9a0abd8
Author
Parents
Loading