pytorch
4acdc446 - [MPS] Fix batch norm for NHWC (#94760)

Commit
1 year ago
[MPS] Fix batch norm for NHWC (#94760) Fixes `test_modules.py` batch norm NHWC testcases: - `test_memory_format_nn_BatchNorm2d_eval_mode_mps_float32` - `test_memory_format_nn_BatchNorm2d_eval_mode_mps_float32` Pull Request resolved: https://github.com/pytorch/pytorch/pull/94760 Approved by: https://github.com/kulinseth
Author
Committer
Parents
Loading