Add BFloat16 dtype support for oneDNN Graph JIT fuser (#85591)
## BFloat16 dtype support for faster inference with TorchScript using oneDNN Graph
The Intel Xeon Cooper Lake platform and later generations support the `AVX512_BF16` ISA, which essentially provides native BFloat16 support.
oneDNN Graph delivers high inference performance with BFloat16 on such machines.
While oneDNN Graph can still be used with BFloat16 on older machines that lack the `avx512_bf16` ISA but do support the `avx512bw`, `avx512vl` & `avx512dq` ISAs, BF16 performance on those machines will be significantly worse (possibly even worse than Float32), as they lack native BF16 support.
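For reference, a quick way to check which of these ISA flags a Linux machine exposes is to read `/proc/cpuinfo` (a minimal sketch, not part of this PR; it assumes Linux, and the flag names are the ones the kernel reports):
```python
# Minimal sketch (Linux only): check /proc/cpuinfo for the ISA flags mentioned above.
# The flag names are the ones reported by the Linux kernel.
def cpu_flags():
    with open("/proc/cpuinfo") as f:
        for line in f:
            if line.startswith("flags"):
                return set(line.split(":", 1)[1].split())
    return set()

flags = cpu_flags()
print("native BF16 (avx512_bf16):", "avx512_bf16" in flags)
print("avx512bw/avx512vl/avx512dq:", {"avx512bw", "avx512vl", "avx512dq"} <= flags)
```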
Currently, [AMP support for eager mode & JIT mode is divergent in PyTorch](https://github.com/pytorch/pytorch/issues/75956).
So, to use oneDNN Graph with BFloat16, AMP should be leveraged in eager mode while being turned off for JIT mode with `torch._C._jit_set_autocast_mode(False)` in Python code, so as to avoid conflicts.
Please set the following environment variable to view JIT logs:
`PYTORCH_JIT_LOG_LEVEL=">>graph_helper:>>graph_fuser:>>kernel:>>interface"`
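Setting the variable in the shell before launching Python is the usual approach; as a rough sketch, it can also be set from Python, assuming this is done before the JIT compiles the model (i.e., before the variable is first read):
```python
import os

# Sketch only: this must run before the JIT compiles/logs anything, since the
# logging configuration reads this environment variable at that point.
os.environ["PYTORCH_JIT_LOG_LEVEL"] = ">>graph_helper:>>graph_fuser:>>kernel:>>interface"

import torch  # noqa: E402  (imported after setting the variable, to be safe)
```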
## Changes being made in this PR
1. This PR does NOT change the `oneDNN` commit or `ideep`'s own files. While the `ideep` submodule commit is being updated, only the files pertaining to oneDNN Graph change within it. oneDNN Graph is being upgraded to version 0.5.2 (alpha patch release 2).
To put things into perspective, `ideep` is a git submodule of PyTorch. `oneDNN Graph` is a git submodule of `ideep` (`ideep/mkl-dnn`), and oneDNN is a git submodule of oneDNN Graph (`ideep/mkl-dnn/third_party/oneDNN`).
2. Unit-tests are being updated. We now use the [existing dtypes decorator](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_device_type.py#L123-L131) (a usage sketch follows this list).
3. Suggestions made by @eellison in the [FP32 PR](https://github.com/pytorch/pytorch/pull/68111#pullrequestreview-896719477) are being incorporated/addressed -
| Action-item | Status |
| :--- | ---: |
| checkInputCompatibility follow-up | Fixed |
| the mayConvertScalarInputToTensor logic we can consider | Added type-promotion code |
| fix up fixConvOptionalBias | The current approach seems correct |
| Use OpInfo tests | Using the dtypes decorator for now. We will use `OpInfo` in a subsequent PR, if possible. Should we create a list of ops from opDB that are supported by oneDNN Graph, and add it to `common_methods_invocations.py`? |
| inferDevice torch_check call | Probably not necessary now, as only CPU is supported. We'd add it by the oneDNN Graph beta release, though, so that users could then combine other fusers with oneDNN Graph (NNC/TensorExpr are already compatible with the oneDNN Graph fuser). We can still add it now if you'd insist. |
| not checking shapes of input mkldnn tensor to llga guard | Those checks should not be present because oneDNN Graph may use blocked or channels-last layouts, so the strides would differ. They're only skipped when an LLGA subgraph's output feeds another LLGA subgraph, which lets LLGA choose an optimal layout between them. |
| fix test failures with respect to unsupported inputs | We'll address them with the upcoming oneDNN Graph beta release |
4. More PyTorch ops are being mapped to oneDNN Graph.
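As a rough illustration of item 2 above, here is a minimal sketch of how the `dtypes` decorator parameterizes a CPU test over Float32 and BFloat16 (the test class name and the conv-relu pattern are hypothetical, not the actual tests added in this PR):
```python
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)

class TestOneDNNGraphDtypes(TestCase):  # hypothetical test class name
    @dtypes(torch.float32, torch.bfloat16)
    def test_conv_relu(self, device, dtype):
        # The dtypes decorator instantiates this test once per listed dtype,
        # passing the dtype in as the `dtype` argument.
        x = torch.rand(1, 3, 8, 8, device=device, dtype=dtype)
        conv = torch.nn.Conv2d(3, 8, kernel_size=3).to(device=device, dtype=dtype)
        y = torch.nn.functional.relu(conv(x))
        self.assertEqual(y.dtype, dtype)

instantiate_device_type_tests(TestOneDNNGraphDtypes, globals(), only_for="cpu")

if __name__ == "__main__":
    run_tests()
```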
## Example of using oneDNN Graph with BFloat16
```python
import torch

# Assuming we have a model named 'model'
example_input = torch.rand(1, 3, 224, 224)

# Enable oneDNN Graph fusion
torch.jit.enable_onednn_fusion(True)
# Disable AMP for JIT to avoid conflicts with eager-mode AMP
torch._C._jit_set_autocast_mode(False)

with torch.no_grad(), torch.cpu.amp.autocast():
    model = torch.jit.trace(model, (example_input,))
    model = torch.jit.freeze(model)
    # 2 warm-up runs (2 for tracing/scripting with an example input, 3 without one)
    model(example_input)
    model(example_input)

# Speedup would be observed in subsequent runs
model(example_input)
```
## TorchBench based Benchmarks
**URL:** https://github.com/sanchitintel/benchmark/tree/onednn_graph_benchmark (instructions present at URL).
**Batch-size(s):** TorchBench-default for each model
**Baseline:** PyTorch JIT OFI FP32
**Machine:** Intel(R) Xeon(R) Platinum 8371HC (Cooper Lake)
**Sockets used:** 1
**Number of cores on one socket:** 26
Intel OpenMP & tcmalloc were preloaded
#### Benchmark results with single thread
| Name | Latency of PyTorch JIT OFI FP32 (s) | Latency of oneDNN Graph BF16 (s) | % change |
| :--- | ---: | ---: | ---: |
| test_eval[alexnet-cpu-jit] | 1.063851 | 0.509820 | -52.1% |
| test_eval[mnasnet1_0-cpu-jit] | 0.218435 | 0.107100 | -51.0% |
| test_eval[mobilenet_v2-cpu-jit] | 0.114467 | 0.058359 | -49.0% |
| test_eval[mobilenet_v3_large-cpu-jit] | 0.233873 | 0.117614 | -49.7% |
| test_eval[resnet18-cpu-jit] | 0.160584 | 0.075854 | -52.8% |
| test_eval[resnet50-cpu-jit] | 1.652846 | 0.713373 | -56.8% |
| test_eval[resnext50_32x4d-cpu-jit] | 0.471174 | 0.209431 | -55.6% |
| test_eval[shufflenet_v2_x1_0-cpu-jit] | 0.310306 | 0.167090 | -46.2% |
| test_eval[squeezenet1_1-cpu-jit] | 0.161247 | 0.045684 | -71.7% |
| test_eval[timm_efficientnet-cpu-jit] | 1.643772 | 0.800099 | -51.3% |
| test_eval[timm_regnet-cpu-jit] | 5.732272 | 2.333417 | -59.3% |
| test_eval[timm_resnest-cpu-jit] | 1.366464 | 0.715252 | -47.7% |
| test_eval[timm_vision_transformer-cpu-jit] | 0.508521 | 0.271598 | -46.6% |
| test_eval[timm_vovnet-cpu-jit] | 2.756692 | 1.125033 | -59.2% |
| test_eval[vgg16-cpu-jit] | 0.711533 | 0.312344 | -56.1% |
#### Benchmark results with 26 threads
| Name | Latency of PyTorch JIT OFI FP32 (s) | Latency of oneDNN Graph BF16 (s) | % change |
| :--- | ---: | ---: | ---: |
| test_eval[alexnet-cpu-jit] | 0.062871 | 0.034198 | -45.6% |
| test_eval[mnasnet1_0-cpu-jit] | 0.022490 | 0.008172 | -63.7% |
| test_eval[mobilenet_v2-cpu-jit] | 0.012730 | 0.005866 | -53.9% |
| test_eval[mobilenet_v3_large-cpu-jit] | 0.025948 | 0.010346 | -60.1% |
| test_eval[resnet18-cpu-jit] | 0.011194 | 0.005726 | -48.9% |
| test_eval[resnet50-cpu-jit] | 0.124662 | 0.045599 | -63.4% |
| test_eval[resnext50_32x4d-cpu-jit] | 0.034737 | 0.015214 | -56.2% |
| test_eval[shufflenet_v2_x1_0-cpu-jit] | 0.028820 | 0.012517 | -56.6% |
| test_eval[squeezenet1_1-cpu-jit] | 0.012557 | 0.003876 | -69.1% |
| test_eval[timm_efficientnet-cpu-jit] | 0.203177 | 0.051879 | -74.5% |
| test_eval[timm_regnet-cpu-jit] | 0.452050 | 0.151113 | -66.6% |
| test_eval[timm_resnest-cpu-jit] | 0.117072 | 0.052848 | -54.9% |
| test_eval[timm_vision_transformer-cpu-jit] | 0.046048 | 0.023275 | -49.5% |
| test_eval[timm_vovnet-cpu-jit] | 0.213187 | 0.077482 | -63.7% |
| test_eval[vgg16-cpu-jit] | 0.044726 | 0.021998 | -50.8% |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85591
Approved by: https://github.com/jgong5, https://github.com/frank-wei, https://github.com/chunyuan-w