pytorch
2c4a6458 - fix mkldnn_add in-place behavior (#51687)

Commit
3 years ago
fix mkldnn_add in-place behavior (#51687)

Summary:
There are two ways to call add in-place:

```python
torch.add(a, b, out=a) # (1) a is updated in-place
torch.add(a, b, out=b) # (2) b is updated in-place
```

If a and b are mkldnn Tensors, case (2) produces an incorrect result.

**Sample code to reproduce the behavior:**

```python
import torch

torch.manual_seed(4)
a = torch.randn(4, 4)
b = torch.randn(4, 4)
b.fill_(1.0)

a_mkl = a.to_mkldnn()
b_mkl = b.to_mkldnn()

torch.add(b, a, alpha=1.0, out=a)
torch.add(b_mkl, a_mkl, alpha=1.0, out=a_mkl)

print(a)
print(a_mkl)
```

**Results:**

Actual:

```python
tensor([[ 0.0586,  2.2632,  0.8162,  1.1505],
        [ 1.1075,  0.7220, -1.6021,  1.6245],
        [ 0.1316,  0.7949,  1.3976,  1.6699],
        [ 0.9463,  1.0467, -0.7671, -1.1205]])
tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]], layout=torch._mkldnn)
```

Expected:

```python
tensor([[ 0.0586,  2.2632,  0.8162,  1.1505],
        [ 1.1075,  0.7220, -1.6021,  1.6245],
        [ 0.1316,  0.7949,  1.3976,  1.6699],
        [ 0.9463,  1.0467, -0.7671, -1.1205]])
tensor([[ 0.0586,  2.2632,  0.8162,  1.1505],
        [ 1.1075,  0.7220, -1.6021,  1.6245],
        [ 0.1316,  0.7949,  1.3976,  1.6699],
        [ 0.9463,  1.0467, -0.7671, -1.1205]], layout=torch._mkldnn)
```

This happens because `dnnl::sum`, which `mkldnn_add` calls, has the following specification:

[oneDNN doc : Sum](https://oneapi-src.github.io/oneDNN/dev_guide_sum.html)

> The sum primitive supports in-place operation, meaning that the src0 tensor can be used as both input and output.
> In-place operation overwrites the original data. Using in-place operation requires the memory footprint of the
> output tensor to be either bigger than or equal to the size of the dst memory descriptor used for primitive creation.

In case (2), however, the output tensor is the second argument rather than src0, so the in-place requirement is not met and the result is accumulated incorrectly. The fix swaps a and b before passing them to `dnnl::sum` in case (2), so that the output tensor is always src0.
**Environment**
- CPU: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
- Build: `USE_MKLDNN=1`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51687

Reviewed By: jbschlosser

Differential Revision: D27062172

Pulled By: VitalyFedyunin

fbshipit-source-id: bf76d36f9fdb1b4337d71d87bcdbaf4edb11f12f
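
The aliasing problem and the swap can be illustrated with a small pure-Python toy model. This is only a sketch: `sum_into` and `add_out` are hypothetical names, and the assumption that the sum copies the first source into the destination and then accumulates the remaining sources is a guess that happens to reproduce the all-2s output above. It is not the oneDNN implementation, and it ignores `alpha`.

```python
def sum_into(dst, srcs):
    """Toy in-place sum (assumed behavior, not oneDNN's actual code):
    copy srcs[0] into dst, then accumulate the remaining sources."""
    first = srcs[0]
    for i in range(len(dst)):
        dst[i] = first[i]
    for src in srcs[1:]:
        for i in range(len(dst)):
            dst[i] += src[i]
    return dst


def add_out(a, b, out):
    """Hypothetical wrapper mirroring the fix: if the output aliases the
    second operand, swap the operands so the output is always srcs[0]."""
    if out is b:
        a, b = b, a
    return sum_into(out, [a, b])


# Case (2) without the swap: the output aliases the second source,
# so it is overwritten by b before it is read.
a = [0.5, -1.0, 2.0]
b = [1.0, 1.0, 1.0]
buggy = sum_into(a, [b, a])
print(buggy)

# Case (2) with the swap: the output is moved to the first position.
a = [0.5, -1.0, 2.0]
fixed = add_out(b, a, out=a)
print(fixed)
```

In the toy model the unswapped call yields `2 * b` because the destination is clobbered before its original values are added, while the swapped call yields `a + b`. Addition is commutative, so swapping the operands does not change the mathematical result.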