rocblas alt impl during backward pass only (#13352)
On AMD Instinct MI200 GPUs, the FP16 and BF16 V_DOT2 and MFMA matrix
instructions flush input and output denormal values to zero. When
training using FP16 precision, some models may fail to converge with
FP16 denorms flushed to zero. The affected instructions are only used by
rocBLAS (GEMM) and MIOpen (convolution) kernels; all other onnxruntime
operations will not encounter this behavior. All other supported AMD
GPUs will not encounter this behavior.
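For concreteness, any FP16 magnitude below 2^-14 (the smallest normal half-precision value) can only be encoded as a denormal, which is exactly what these instructions flush to zero. A minimal sketch using only the Python standard library's half-precision packing (the helper name is illustrative):

```python
import struct

def to_fp16_bits(x: float) -> int:
    # Round-trip a Python float through IEEE 754 binary16 and return the raw bits.
    return struct.unpack("<H", struct.pack("<e", x))[0]

FP16_MIN_NORMAL = 2.0 ** -14    # smallest positive normal half
FP16_MIN_DENORMAL = 2.0 ** -24  # smallest positive denormal half

# A gradient magnitude below 2^-14 is representable only as a denormal:
# exponent field == 0 with a nonzero mantissa.
bits = to_fp16_bits(FP16_MIN_DENORMAL)
exponent = (bits >> 10) & 0x1F
mantissa = bits & 0x3FF
assert exponent == 0 and mantissa != 0

# On MI200, V_DOT2/MFMA would flush such an input to zero, so this tiny
# value would contribute nothing to the dot product.
```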
rocBLAS and MIOpen provide alternate implementations for affected FP16
operations. Alternate implementations for BF16 operations are not
provided; BF16 numbers have a larger dynamic range than FP16 numbers and
are less likely to encounter denormal values. For the FP16 alternate
implementations, FP16 input values are cast to intermediate BF16 values
and then cast back to FP16 outputs after accumulation in FP32. In this
way, the input and output types are unchanged.
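The key point of that cast chain is that BF16 shares FP32's 8-bit exponent, so every FP16 denormal is a normal BF16 value and survives. A rough sketch, assuming truncation to BF16 (real kernels may round; the function names are illustrative, not rocBLAS symbols):

```python
import struct

def fp32_to_bf16(x: float) -> float:
    # Truncate to BF16 by keeping the top 16 bits of the FP32 encoding
    # (round-to-nearest is omitted for simplicity in this sketch).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def fp16_dot_alt(a, b):
    # Sketch of the alternate path: FP16 inputs -> BF16 intermediates,
    # FP32 accumulation, FP16 output (Python float stands in for FP32).
    acc = 0.0
    for x, y in zip(a, b):
        acc += fp32_to_bf16(x) * fp32_to_bf16(y)
    # Cast the accumulator back to FP16 for the output.
    return struct.unpack("<e", struct.pack("<e", acc))[0]
```

Because 2^-24 (an FP16 denormal) is a normal BF16 number, it passes through this path intact instead of being flushed.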
Denormal values occur more frequently in the backward pass of training,
during gradient calculation. Therefore, it is necessary to track when
the backward pass of training is executing. For the ROCm EP only, the
`__backwardpass` attribute is added to all Nodes after the YieldOp is
detected. This takes place in a level1 graph optimization pass. The
attribute is forwarded to any newly created FusedMatMul Nodes. In
addition, the scope-based helper class `BackwardPassGuard` is provided
to toggle state for rocBLAS. This behavior of using the alternate
implementations during the backward pass is made automatic with this PR.
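The guard's semantics can be sketched in Python as a context manager; the thread-local flag and names below are illustrative stand-ins for the C++ RAII helper, not the actual onnxruntime symbols:

```python
import threading

# Thread-local flag standing in for the rocBLAS alt-impl state.
_state = threading.local()

class BackwardPassGuard:
    # Scope-based guard: set the flag on entry, restore the previous
    # value on exit, mirroring the RAII-style C++ helper described above.
    def __enter__(self):
        self._prev = getattr(_state, "backward", False)
        _state.backward = True
        return self

    def __exit__(self, *exc):
        _state.backward = self._prev
        return False

def is_backward_pass() -> bool:
    return getattr(_state, "backward", False)
```

Restoring the saved value (rather than unconditionally clearing it) keeps nested guards well behaved.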
This default behavior can be overridden using environment variables,
ROCBLAS_INTERNAL_FP16_ALT_IMPL and
MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL. The behavior of these
environment variables is as follows:
| Env value    | forward   | backward  |
|--------------|-----------|-----------|
| Env unset | original | alternate |
| Env set to 1 | alternate | alternate |
| Env set to 0 | original | original |
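The table above can be read as a simple decision rule. A sketch (the function name is illustrative; values other than "1" are treated like "0" here, which the table leaves unspecified):

```python
import os

def use_alt_impl(env_var: str, is_backward: bool) -> bool:
    # Unset  -> alternate only during the backward pass.
    # "1"    -> always alternate.
    # "0"    -> never alternate (other values treated the same in this sketch).
    value = os.environ.get(env_var)
    if value is None:
        return is_backward
    return value == "1"
```

For example, `use_alt_impl("ROCBLAS_INTERNAL_FP16_ALT_IMPL", is_backward=True)` returns True when the variable is unset, matching the first row of the table.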
See also:
https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices