pytorch
a17a7ccc - [MPS] LogSoftmax numerical stability (#95091)

Commit
2 years ago
[MPS] LogSoftmax numerical stability (#95091) Fixes #94043 Calculations are now consistent with numericaly stable formula and CPU: $LogSoftmax(X, \dim) = X - \max(X, \dim) - \log(sum(X - \max(X, \dim), \dim))$ @malfet Pull Request resolved: https://github.com/pytorch/pytorch/pull/95091 Approved by: https://github.com/malfet, https://github.com/kulinseth
Author
Committer
Parents
  • aten/src/ATen/native/mps/operations
    • File
      Activation.mm
  • test
    • File
      test_mps.py
Loading