pytorch
075a4944 - [MPS] Allow `float16` input to float32 `LayerNorm` (#96430)

Commit
1 year ago
[MPS] Allow `float16` input to float32 `LayerNorm` (#96430) Only for forward pass Subset of https://github.com/pytorch/pytorch/pull/96208 Create constant with scalar using `input_mps_dtype` and use `reciprocalWithTensor` instead of `divisionWithPrimaryTensor:1.0 secondaryTensor:` Fixes https://github.com/pytorch/pytorch/issues/96113 Pull Request resolved: https://github.com/pytorch/pytorch/pull/96430 Approved by: https://github.com/kulinseth
Author
Committer
Parents
Loading