Add AccumulateType in AdaptiveAveragePooling3d.cu (#53607)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/52719
- Changed the type(`scalar_t`) of intermediate results to `at::acc_type<scalar_t, true>`
This issue occurs by decimal precision of the half precision.
Follows test cases of upper issue, The value range of input tensors are [0, 1] because init by `rand`.
And when the kernel size 1, summations all target values and divide numel of kernel
https://github.com/pytorch/pytorch/blob/34d9278c1913d83db47a25f0cac71b69ab877c84/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu#L94-L95
When adding [0, 1] values, if `sum` more than 2048 then not changed values. ( Even if the value is small, the mored exact value is added, but there are still precision issues.)
(https://en.wikipedia.org/wiki/Half-precision_floating-point_format)
Benchmarks
- In V100 32GB, Driver : 450.80, cuda 10.1
- faster than prev
<details><summary>Script</summary><p>
```import torch
from torch.utils.benchmark import Timer
torch.manual_seed(0)
kernel_sizes = [1, 3, 5, 7, 9, 11, 13]
shapes = [(12, 12, 12), (16, 16, 16), (16, 32, 32), (16, 56, 56), (16, 112, 112)]
def run(batch, channel):
print(f"Batch : {batch}, Channel : {channel} / (diff, diff / numel, time)")
head = "\t".join(f"{str(s):30s}" for s in ["k \ shape"] + shapes)
print(head)
for kernel_size in kernel_sizes:
kernel_size = (kernel_size, kernel_size, kernel_size)
pool = torch.nn.AdaptiveAvgPool3d(kernel_size)
print(f"{str(kernel_size):30s}", end="\t")
for shape in shapes:
x_half = torch.rand([batch, channel, *shape], dtype=torch.half, device="cuda")
x_float = x_half.float()
y_half = pool(x_half)
y_float = pool(x_float)
timer = Timer("pool(x_half)", globals={"pool": pool, "x_half": x_half})
measurement = timer.blocked_autorange(min_run_time=5)
diff = (y_float - y_half).abs().sum().item()
diff = f"{diff:.4f}, {diff / y_half.numel():.6f}, {measurement.median * 1e6 :3.2f}us"
print(f"{diff:30s}", end="\t")
print("")
run(1, 1)
run(1, 3)
run(1, 54)
run(1, 16)
run(8, 1)
run(8, 16)
run(8, 54)
import torch
m = torch.nn.AdaptiveAvgPool3d((1,1,1))
inputs = torch.rand([8,54,16,56,56])
inputs = inputs.cuda()
inputs_2 = inputs.half()
print("Float")
out = m(inputs).float()
print("half")
out2 = m(inputs_2).float()
print('Discepancies', torch.sum(torch.abs(out2- out)).item(), torch.sum(torch.abs(out2- out)).item() / out.numel() , out.numel())
print("Sum : ", torch.sum(inputs, dim=(2,3,4))[0, 0], torch.sum(inputs_2, dim=(2,3,4))[0, 0])
```
</p>
</details>
<details><summary>This commit</summary><p>
```
Batch : 1, Channel : 1 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0001, 0.000078, 55.73us 0.0001, 0.000079, 117.51us 0.0000, 0.000003, 379.60us 0.0000, 0.000046, 1046.21us 0.0001, 0.000139, 3897.17us
(3, 3, 3) 0.0021, 0.000076, 22.04us 0.0031, 0.000115, 21.47us 0.0022, 0.000080, 41.63us 0.0030, 0.000111, 100.59us 0.0025, 0.000091, 295.04us
(5, 5, 5) 0.0103, 0.000083, 21.65us 0.0097, 0.000078, 21.37us 0.0103, 0.000083, 21.60us 0.0114, 0.000091, 25.69us 0.0107, 0.000085, 97.06us
(7, 7, 7) 0.0312, 0.000091, 21.52us 0.0290, 0.000084, 21.61us 0.0311, 0.000091, 21.60us 0.0309, 0.000090, 21.44us 0.0334, 0.000097, 33.60us
(9, 9, 9) 0.0646, 0.000089, 21.57us 0.0672, 0.000092, 21.89us 0.0662, 0.000091, 21.89us 0.0684, 0.000094, 27.64us 0.0660, 0.000091, 54.85us
(11, 11, 11) 0.1251, 0.000094, 21.68us 0.1194, 0.000090, 21.70us 0.1202, 0.000090, 21.72us 0.1233, 0.000093, 22.25us 0.1229, 0.000092, 41.39us
(13, 13, 13) 0.2038, 0.000093, 21.57us 0.2047, 0.000093, 21.58us 0.1964, 0.000089, 21.54us 0.2021, 0.000092, 21.94us 0.1989, 0.000091, 40.01us
Batch : 1, Channel : 3 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0003, 0.000110, 55.74us 0.0003, 0.000093, 118.62us 0.0003, 0.000093, 382.12us 0.0001, 0.000040, 1052.33us 0.0003, 0.000114, 3917.90us
(3, 3, 3) 0.0073, 0.000090, 21.84us 0.0075, 0.000093, 22.25us 0.0072, 0.000089, 41.78us 0.0070, 0.000087, 100.27us 0.0069, 0.000086, 293.96us
(5, 5, 5) 0.0353, 0.000094, 22.57us 0.0325, 0.000087, 21.64us 0.0343, 0.000092, 22.63us 0.0338, 0.000090, 25.82us 0.0332, 0.000089, 97.16us
(7, 7, 7) 0.0937, 0.000091, 22.50us 0.0910, 0.000088, 21.92us 0.0933, 0.000091, 21.99us 0.0948, 0.000092, 21.56us 0.0928, 0.000090, 34.17us
(9, 9, 9) 0.1957, 0.000089, 21.68us 0.1984, 0.000091, 21.57us 0.2025, 0.000093, 22.10us 0.1986, 0.000091, 27.66us 0.2020, 0.000092, 55.32us
(11, 11, 11) 0.3585, 0.000090, 21.75us 0.3684, 0.000092, 22.70us 0.3706, 0.000093, 21.67us 0.3752, 0.000094, 21.86us 0.3663, 0.000092, 41.22us
(13, 13, 13) 0.5931, 0.000090, 21.67us 0.6056, 0.000092, 21.79us 0.6005, 0.000091, 21.79us 0.6112, 0.000093, 21.69us 0.6034, 0.000092, 40.02us
Batch : 1, Channel : 54 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0051, 0.000095, 55.76us 0.0060, 0.000112, 118.60us 0.0036, 0.000067, 381.50us 0.0054, 0.000100, 1054.03us 0.0048, 0.000089, 4888.68us
(3, 3, 3) 0.1332, 0.000091, 21.66us 0.1344, 0.000092, 22.62us 0.1354, 0.000093, 45.72us 0.1364, 0.000094, 106.63us 0.1324, 0.000091, 448.31us
(5, 5, 5) 0.6221, 0.000092, 22.48us 0.6220, 0.000092, 21.71us 0.6053, 0.000090, 27.65us 0.6137, 0.000091, 31.40us 0.6209, 0.000092, 172.78us
(7, 7, 7) 1.6859, 0.000091, 22.42us 1.6972, 0.000092, 21.96us 1.6849, 0.000091, 23.14us 1.7012, 0.000092, 26.25us 1.6920, 0.000091, 75.58us
(9, 9, 9) 3.5811, 0.000091, 21.73us 3.5746, 0.000091, 22.55us 3.6237, 0.000092, 27.66us 3.6046, 0.000092, 59.71us 3.6392, 0.000092, 168.15us
(11, 11, 11) 6.5582, 0.000091, 22.05us 6.5746, 0.000091, 21.74us 6.5955, 0.000092, 32.91us 6.5644, 0.000091, 45.57us 6.5697, 0.000091, 114.01us
(13, 13, 13) 10.6384, 0.000090, 21.81us 10.8608, 0.000092, 21.79us 10.8375, 0.000091, 37.01us 10.8662, 0.000092, 51.80us 10.8593, 0.000092, 123.19us
Batch : 1, Channel : 16 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0015, 0.000093, 55.75us 0.0012, 0.000075, 118.10us 0.0013, 0.000079, 379.25us 0.0012, 0.000075, 1047.21us 0.0013, 0.000079, 4451.57us
(3, 3, 3) 0.0407, 0.000094, 21.82us 0.0395, 0.000091, 21.69us 0.0385, 0.000089, 42.07us 0.0397, 0.000092, 100.33us 0.0384, 0.000089, 363.31us
(5, 5, 5) 0.1858, 0.000093, 21.76us 0.1799, 0.000090, 21.63us 0.1834, 0.000092, 21.76us 0.1890, 0.000095, 26.04us 0.1814, 0.000091, 135.32us
(7, 7, 7) 0.4937, 0.000090, 21.65us 0.5076, 0.000092, 21.69us 0.5001, 0.000091, 22.31us 0.4988, 0.000091, 21.59us 0.5123, 0.000093, 50.03us
(9, 9, 9) 1.0678, 0.000092, 21.73us 1.0752, 0.000092, 21.75us 1.0673, 0.000091, 21.75us 1.0649, 0.000091, 30.01us 1.0786, 0.000092, 70.92us
(11, 11, 11) 1.9591, 0.000092, 21.57us 1.9522, 0.000092, 21.60us 1.9566, 0.000092, 21.73us 1.9475, 0.000091, 23.46us 1.9323, 0.000091, 55.02us
(13, 13, 13) 3.1784, 0.000090, 22.02us 3.2165, 0.000092, 21.95us 3.1969, 0.000091, 21.92us 3.2061, 0.000091, 24.40us 3.2578, 0.000093, 56.00us
Batch : 8, Channel : 1 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0010, 0.000122, 55.74us 0.0009, 0.000114, 118.82us 0.0006, 0.000074, 379.80us 0.0009, 0.000107, 1047.31us 0.0008, 0.000102, 3900.36us
(3, 3, 3) 0.0219, 0.000101, 21.57us 0.0200, 0.000093, 21.61us 0.0194, 0.000090, 41.74us 0.0208, 0.000096, 99.91us 0.0212, 0.000098, 293.03us
(5, 5, 5) 0.0906, 0.000091, 21.46us 0.0911, 0.000091, 21.60us 0.0934, 0.000093, 21.93us 0.0927, 0.000093, 25.74us 0.0913, 0.000091, 96.85us
(7, 7, 7) 0.2530, 0.000092, 22.53us 0.2526, 0.000092, 22.46us 0.2558, 0.000093, 22.03us 0.2542, 0.000093, 22.29us 0.2475, 0.000090, 34.44us
(9, 9, 9) 0.5305, 0.000091, 22.34us 0.5368, 0.000092, 22.42us 0.5265, 0.000090, 21.74us 0.5370, 0.000092, 27.81us 0.5416, 0.000093, 55.65us
(11, 11, 11) 0.9887, 0.000093, 21.80us 0.9660, 0.000091, 21.61us 0.9793, 0.000092, 22.11us 0.9719, 0.000091, 21.80us 0.9650, 0.000091, 43.90us
(13, 13, 13) 1.6024, 0.000091, 21.87us 1.6198, 0.000092, 22.65us 1.6242, 0.000092, 21.73us 1.6236, 0.000092, 22.59us 1.6025, 0.000091, 42.77us
Batch : 8, Channel : 16 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0113, 0.000088, 56.66us 0.0117, 0.000091, 119.57us 0.0130, 0.000102, 389.57us 0.0110, 0.000086, 1433.78us 0.0119, 0.000093, 5217.61us
(3, 3, 3) 0.3209, 0.000093, 21.54us 0.3184, 0.000092, 22.87us 0.3115, 0.000090, 51.00us 0.3171, 0.000092, 164.17us 0.3182, 0.000092, 500.60us
(5, 5, 5) 1.4391, 0.000090, 22.39us 1.4577, 0.000091, 21.69us 1.4601, 0.000091, 53.87us 1.4626, 0.000091, 93.65us 1.4567, 0.000091, 370.11us
(7, 7, 7) 4.0501, 0.000092, 22.34us 4.0230, 0.000092, 31.45us 4.0381, 0.000092, 45.19us 4.0171, 0.000091, 65.35us 4.0108, 0.000091, 164.76us
(9, 9, 9) 8.5360, 0.000091, 22.80us 8.5456, 0.000092, 27.24us 8.5461, 0.000092, 50.23us 8.5677, 0.000092, 117.63us 8.5645, 0.000092, 270.46us
(11, 11, 11) 15.5521, 0.000091, 26.56us 15.5826, 0.000091, 32.81us 15.6014, 0.000092, 63.82us 15.5620, 0.000091, 96.87us 15.5722, 0.000091, 220.24us
(13, 13, 13) 25.4146, 0.000090, 32.91us 25.7898, 0.000092, 38.48us 25.6698, 0.000091, 72.02us 25.8193, 0.000092, 121.73us 25.7718, 0.000092, 249.71us
Batch : 8, Channel : 54 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0377, 0.000087, 109.07us 0.0405, 0.000094, 233.17us 0.0392, 0.000091, 998.97us 0.0393, 0.000091, 2960.68us 0.0408, 0.000094, 11879.53us
(3, 3, 3) 1.0660, 0.000091, 25.68us 1.0761, 0.000092, 64.12us 1.0725, 0.000092, 182.50us 1.0801, 0.000093, 505.82us 1.0736, 0.000092, 1650.21us
(5, 5, 5) 4.9587, 0.000092, 50.84us 4.9336, 0.000091, 47.38us 4.9696, 0.000092, 158.49us 4.9347, 0.000091, 237.39us 4.9303, 0.000091, 965.13us
(7, 7, 7) 13.5409, 0.000091, 45.60us 13.5736, 0.000092, 87.45us 13.5012, 0.000091, 141.63us 13.6111, 0.000092, 181.51us 13.5296, 0.000091, 469.77us
(9, 9, 9) 28.7817, 0.000091, 58.01us 28.7969, 0.000091, 77.61us 28.8761, 0.000092, 159.33us 28.8786, 0.000092, 334.47us 28.8093, 0.000091, 786.72us
(11, 11, 11) 52.4453, 0.000091, 78.19us 52.7265, 0.000092, 95.12us 52.7322, 0.000092, 200.38us 52.6342, 0.000092, 282.41us 52.6467, 0.000092, 652.54us
(13, 13, 13) 85.7411, 0.000090, 98.85us 86.7183, 0.000091, 115.28us 86.8545, 0.000092, 232.34us 86.9997, 0.000092, 367.32us 86.9083, 0.000092, 757.73us
Float
half
Discepancies 0.03963914513587952 9.175728040712852e-05 432
Sum : tensor(25110.1484, device='cuda:0') tensor(25104., device='cuda:0', dtype=torch.float16)
```
</p>
</details>
<details><summary>1.8.0</summary><p>
```
Batch : 1, Channel : 1 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0023, 0.002275, 74.35us 0.0040, 0.003985, 159.73us 0.3740, 0.374021, 546.59us 0.4587, 0.458663, 1543.16us 0.4906, 0.490637, 5945.97us
(3, 3, 3) 0.0100, 0.000370, 20.37us 0.0230, 0.000852, 22.12us 0.0309, 0.001143, 54.75us 0.0520, 0.001926, 129.78us 7.1219, 0.263775, 377.11us
(5, 5, 5) 0.0441, 0.000352, 20.06us 0.0394, 0.000316, 20.50us 0.0759, 0.000607, 26.43us 0.1499, 0.001199, 32.01us 0.2707, 0.002166, 128.15us
(7, 7, 7) 0.0791, 0.000231, 20.10us 0.1002, 0.000292, 20.56us 0.1812, 0.000528, 20.48us 0.2424, 0.000707, 20.83us 0.4994, 0.001456, 43.97us
(9, 9, 9) 0.1122, 0.000154, 20.55us 0.1778, 0.000244, 20.44us 0.2572, 0.000353, 20.15us 0.4149, 0.000569, 35.64us 0.7208, 0.000989, 68.46us
(11, 11, 11) 0.2044, 0.000154, 20.47us 0.2647, 0.000199, 20.62us 0.3867, 0.000291, 20.61us 0.6059, 0.000455, 23.54us 1.0902, 0.000819, 53.32us
(13, 13, 13) 0.3094, 0.000141, 20.53us 0.3843, 0.000175, 20.60us 0.5756, 0.000262, 20.80us 0.8598, 0.000391, 24.52us 1.4853, 0.000676, 47.70us
Batch : 1, Channel : 3 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0054, 0.001801, 74.36us 0.0108, 0.003614, 158.94us 1.1183, 0.372768, 547.67us 1.3782, 0.459387, 1545.27us 1.4685, 0.489505, 5949.17us
(3, 3, 3) 0.0308, 0.000380, 20.14us 0.0502, 0.000619, 22.11us 0.1210, 0.001493, 54.80us 0.1900, 0.002345, 130.47us 21.3483, 0.263560, 375.68us
(5, 5, 5) 0.1179, 0.000314, 20.68us 0.1326, 0.000354, 20.53us 0.2662, 0.000710, 26.51us 0.4116, 0.001098, 31.85us 0.8369, 0.002232, 128.19us
(7, 7, 7) 0.2335, 0.000227, 20.40us 0.3057, 0.000297, 20.43us 0.4954, 0.000481, 20.31us 0.7339, 0.000713, 20.74us 1.4208, 0.001381, 44.55us
(9, 9, 9) 0.3326, 0.000152, 20.63us 0.5353, 0.000245, 20.42us 0.8025, 0.000367, 20.13us 1.2693, 0.000580, 35.64us 2.2096, 0.001010, 68.88us
(11, 11, 11) 0.6121, 0.000153, 20.59us 0.8086, 0.000202, 20.42us 1.1700, 0.000293, 20.71us 1.8170, 0.000455, 23.54us 3.2117, 0.000804, 53.36us
(13, 13, 13) 0.9165, 0.000139, 20.51us 1.1395, 0.000173, 20.56us 1.7343, 0.000263, 20.80us 2.5868, 0.000392, 24.59us 4.5823, 0.000695, 47.77us
Batch : 1, Channel : 54 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.1092, 0.002023, 75.45us 0.1709, 0.003165, 160.44us 20.2452, 0.374911, 548.61us 24.7990, 0.459240, 1550.34us 26.4494, 0.489804, 6957.79us
(3, 3, 3) 0.5352, 0.000367, 20.58us 1.0281, 0.000705, 24.14us 2.0150, 0.001382, 59.12us 3.3069, 0.002268, 138.23us 384.5216, 0.263732, 529.71us
(5, 5, 5) 2.0739, 0.000307, 20.60us 2.5199, 0.000373, 20.44us 4.6916, 0.000695, 33.89us 7.9482, 0.001178, 37.74us 14.2553, 0.002112, 200.54us
(7, 7, 7) 4.2236, 0.000228, 20.61us 5.5605, 0.000300, 20.97us 9.0440, 0.000488, 26.40us 12.7847, 0.000690, 30.64us 25.3050, 0.001366, 88.05us
(9, 9, 9) 6.0817, 0.000154, 20.63us 9.5416, 0.000242, 20.84us 14.2416, 0.000362, 32.47us 22.8452, 0.000580, 78.57us 40.3246, 0.001024, 194.50us
(11, 11, 11) 11.1144, 0.000155, 20.56us 14.5581, 0.000203, 20.91us 20.8263, 0.000290, 38.07us 33.0004, 0.000459, 52.74us 57.3275, 0.000798, 137.19us
(13, 13, 13) 16.5176, 0.000139, 21.26us 20.8089, 0.000175, 22.33us 31.3433, 0.000264, 42.93us 45.9733, 0.000388, 59.84us 82.8301, 0.000698, 138.42us
Batch : 1, Channel : 16 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0274, 0.001715, 74.99us 0.0485, 0.003034, 159.92us 5.9925, 0.374529, 546.35us 7.3389, 0.458679, 1544.53us 7.8354, 0.489714, 6677.00us
(3, 3, 3) 0.1560, 0.000361, 20.72us 0.3043, 0.000704, 22.37us 0.5838, 0.001352, 54.97us 1.0455, 0.002420, 130.57us 113.9739, 0.263828, 463.43us
(5, 5, 5) 0.6121, 0.000306, 20.12us 0.7247, 0.000362, 20.73us 1.3740, 0.000687, 26.59us 2.3794, 0.001190, 32.12us 4.1929, 0.002096, 165.81us
(7, 7, 7) 1.2389, 0.000226, 20.59us 1.6311, 0.000297, 20.53us 2.6732, 0.000487, 20.37us 3.7501, 0.000683, 20.71us 7.4575, 0.001359, 59.16us
(9, 9, 9) 1.7983, 0.000154, 20.64us 2.8075, 0.000241, 20.59us 4.2165, 0.000361, 20.38us 6.7153, 0.000576, 38.29us 12.0530, 0.001033, 86.33us
(11, 11, 11) 3.3326, 0.000156, 20.56us 4.3061, 0.000202, 20.67us 6.2235, 0.000292, 20.47us 9.8009, 0.000460, 27.41us 16.9994, 0.000798, 68.49us
(13, 13, 13) 4.9016, 0.000139, 20.63us 6.1261, 0.000174, 20.65us 9.2106, 0.000262, 20.93us 13.5843, 0.000386, 27.95us 24.6476, 0.000701, 64.88us
Batch : 8, Channel : 1 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0170, 0.002122, 74.99us 0.0316, 0.003946, 160.66us 3.0013, 0.375158, 546.94us 3.6780, 0.459753, 1544.58us 3.9197, 0.489966, 5948.43us
(3, 3, 3) 0.0821, 0.000380, 20.27us 0.1559, 0.000722, 22.29us 0.3133, 0.001450, 54.72us 0.5100, 0.002361, 130.12us 57.0481, 0.264111, 376.71us
(5, 5, 5) 0.3075, 0.000307, 20.57us 0.3680, 0.000368, 20.69us 0.6786, 0.000679, 26.61us 1.1744, 0.001174, 31.77us 2.0654, 0.002065, 128.31us
(7, 7, 7) 0.6512, 0.000237, 20.60us 0.8359, 0.000305, 20.50us 1.3712, 0.000500, 20.75us 1.9472, 0.000710, 20.92us 3.7586, 0.001370, 44.59us
(9, 9, 9) 0.9138, 0.000157, 20.43us 1.4198, 0.000243, 20.58us 2.1018, 0.000360, 20.52us 3.3691, 0.000578, 35.90us 5.9491, 0.001020, 69.16us
(11, 11, 11) 1.6606, 0.000156, 20.63us 2.1599, 0.000203, 20.57us 3.1240, 0.000293, 20.98us 4.8874, 0.000459, 24.65us 8.4780, 0.000796, 56.47us
(13, 13, 13) 2.4987, 0.000142, 20.71us 3.0667, 0.000174, 20.45us 4.6387, 0.000264, 20.76us 6.8187, 0.000388, 25.95us 12.2077, 0.000695, 50.46us
Batch : 8, Channel : 16 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.2635, 0.002059, 75.66us 0.4030, 0.003149, 161.78us 48.0296, 0.375231, 550.46us 58.7787, 0.459209, 1902.41us 62.6966, 0.489817, 7817.48us
(3, 3, 3) 1.2271, 0.000355, 20.72us 2.4185, 0.000700, 26.44us 4.6933, 0.001358, 64.66us 7.7016, 0.002228, 192.69us 912.0736, 0.263910, 593.69us
(5, 5, 5) 4.8716, 0.000304, 24.75us 5.8624, 0.000366, 21.39us 11.0705, 0.000692, 66.94us 18.9280, 0.001183, 104.93us 34.0512, 0.002128, 441.81us
(7, 7, 7) 10.1713, 0.000232, 20.98us 13.2273, 0.000301, 36.26us 21.5426, 0.000491, 52.18us 30.1910, 0.000688, 72.94us 59.8381, 0.001363, 191.52us
(9, 9, 9) 14.4542, 0.000155, 23.85us 22.6579, 0.000243, 30.59us 33.8839, 0.000363, 57.40us 54.3563, 0.000583, 142.53us 95.8123, 0.001027, 309.24us
(11, 11, 11) 26.3348, 0.000155, 30.07us 34.3043, 0.000201, 37.01us 49.8093, 0.000292, 74.04us 78.3720, 0.000460, 110.53us 136.5404, 0.000801, 264.14us
(13, 13, 13) 39.3550, 0.000140, 37.38us 49.3207, 0.000175, 43.51us 74.1139, 0.000264, 83.70us 108.7627, 0.000387, 136.09us 196.5412, 0.000699, 280.16us
Batch : 8, Channel : 54 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.8467, 0.001960, 147.36us 1.3993, 0.003239, 314.95us 162.0182, 0.375042, 1327.22us 198.3226, 0.459080, 3921.79us 211.6123, 0.489843, 15646.94us
(3, 3, 3) 4.3146, 0.000370, 29.23us 8.1125, 0.000696, 74.94us 15.8886, 0.001362, 223.69us 26.2404, 0.002250, 601.33us 3076.5354, 0.263763, 1974.06us
(5, 5, 5) 16.5032, 0.000306, 58.79us 19.6887, 0.000365, 53.79us 37.2731, 0.000690, 192.34us 63.3076, 0.001172, 270.01us 114.8880, 0.002128, 1148.56us
(7, 7, 7) 34.0802, 0.000230, 51.12us 44.4087, 0.000300, 100.93us 72.4613, 0.000489, 161.48us 101.9317, 0.000688, 202.91us 201.8955, 0.001363, 545.33us
(9, 9, 9) 48.8179, 0.000155, 65.78us 76.3465, 0.000242, 87.48us 114.0228, 0.000362, 179.11us 182.9805, 0.000581, 403.66us 322.7040, 0.001025, 894.86us
(11, 11, 11) 88.9993, 0.000155, 88.69us 116.4213, 0.000202, 107.55us 168.3363, 0.000293, 228.71us 264.2232, 0.000460, 322.84us 459.1324, 0.000799, 784.25us
(13, 13, 13) 132.7447, 0.000140, 112.91us 165.4525, 0.000174, 131.08us 249.7127, 0.000263, 266.43us 367.0824, 0.000387, 410.17us 663.1367, 0.000699, 847.87us
Float
half
Discepancies 198.37625122070312 0.4592042852331091 432
Sum : tensor(25110.1484, device='cuda:0') tensor(25104., device='cuda:0', dtype=torch.float16)
```
</p>
</details>
ngimel malfet anjali411
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53607
Reviewed By: mruberry
Differential Revision: D27652337
Pulled By: ngimel
fbshipit-source-id: 6439c0cafe6ca3f761a3f5d058050a55e9a0abd8