pytorch
a1f15fb9 - [MPS] Fix batchnorm forward and backward pass (#94351)

Commit
1 year ago
[MPS] Fix batchnorm forward and backward pass (#94351) Fixes batchnorm forward/backward pass and layer_norm: Batchnorm Forward pass: ``` - fix batch_norm_mps_out key - return 1/sqrt(var+epsilon) instead of var - return empty tensor for mean and var if train is not enabled - remove native_batch_norm from block list ``` Batchnorm Backward pass: ``` - add revert caculation for save_var used in backward path - add backward test for native_batch_norm and _native_batch_norm_legit ``` Layer norm: ``` - remove the duplicate calculation from layer_norm_mps - enable native_layer_norm backward test - raise atol rtol for native_layer_norm ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/94351 Approved by: https://github.com/razarmehr
Author
Committer
Parents
Loading