[ONNX] Fix numerical errors in softmax when dim is not last dimension (#37326)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/34585.
This PR improves the workaround for the mismatch in semantics between ONNX `Softmax` and PyTorch softmax.
In PyTorch, the `dim` parameter specifies the dimension over which the values are normalized. ONNX, on the other hand, always coerces the input into a 2D tensor, and the `axis` parameter specifies which dimensions form the rows and columns of that tensor. As a result, the semantics match only when we are normalizing the last dimension (`dim == ndim - 1`).
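To make the mismatch concrete, here is a small eager-mode sketch that emulates the ONNX coercion by hand (the reshape-based emulation is illustrative only, not the exporter code):

```python
import torch

x = torch.randn(2, 3, 4)

# PyTorch: normalize along dim=1 (size 3); each slice x[i, :, k] sums to 1.
pt = torch.softmax(x, dim=1)

# ONNX-style coercion for axis=1: flatten to 2D as [2, 3*4], normalize each
# row of 12 values, then restore the original shape.
onnx_style = torch.softmax(x.reshape(2, -1), dim=1).reshape(2, 3, 4)

print(torch.allclose(pt, onnx_style))  # False: the two disagree
```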
Previously this was handled by recognizing the `dim == ndim - 1` case and using ONNX `Softmax` directly for it. All other cases fell back to explicit invocations of the `Exp`, `ReduceSum`, and `Div` operators to compute the result. Unfortunately, this fallback produces numerical errors when the input values are large: `Exp` overflows to infinity in both the numerator and the denominator, and dividing the two yields NaN.
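The failure mode is easy to reproduce in eager mode:

```python
import torch

x = torch.tensor([[1000.0, 1000.0]])

# Decomposed fallback: exp overflows to inf, and inf / inf yields NaN.
e = torch.exp(x)
print(e / e.sum(dim=-1, keepdim=True))  # tensor([[nan, nan]])

# The fused softmax is numerically stable (it subtracts the max internally).
print(torch.softmax(x, dim=-1))  # tensor([[0.5000, 0.5000]])
```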
This can be improved by transposing the input tensor so that the dimension being normalized becomes the last one, reusing ONNX `Softmax` on it, and transposing the result back.
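A minimal sketch of the idea (the helper below is hypothetical; the actual change lives in the ONNX symbolic for `softmax`):

```python
import torch

def softmax_via_transpose(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Swap `dim` with the last dimension, apply softmax there (where ONNX
    # and PyTorch semantics agree), then swap back.
    dim = dim % x.ndim
    if dim == x.ndim - 1:
        return torch.softmax(x, dim=dim)
    x = x.transpose(dim, x.ndim - 1)
    return torch.softmax(x, dim=-1).transpose(dim, x.ndim - 1)

# Stays finite even for large inputs, since no bare exp/sum is emitted.
t = torch.full((2, 3, 4), 1000.0)
print(torch.allclose(softmax_via_transpose(t, 1), torch.softmax(t, dim=1)))  # True
```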
A similar approach was applied to the `logsoftmax` function in https://github.com/pytorch/pytorch/issues/30433.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37326
Reviewed By: hl475
Differential Revision: D21389712
Pulled By: houseroad
fbshipit-source-id: 554fd1b98231a28984c30c7e7abd3c0643386ff7