pytorch
1211ceea - [MPS] Fix issues with max_pool2d (#95325)

Commit
1 year ago
[MPS] Fix issues with max_pool2d (#95325) * [MPS] Fix upsample for NHWC output (#94963) Fixes https://github.com/huggingface/diffusers/issues/941 **Before**: <img width="1144" alt="Screenshot 2023-02-15 at 8 11 53 PM" src="https://user-images.githubusercontent.com/104024078/219266709-6a77636a-2fc0-4802-b130-85069b95953f.png"> **After**: <img width="1144" alt="Screenshot 2023-02-15 at 8 12 02 PM" src="https://user-images.githubusercontent.com/104024078/219266694-ea743c02-fb55-44f1-b7d6-5946106527c3.png"> Pull Request resolved: https://github.com/pytorch/pytorch/pull/94963 Approved by: https://github.com/razarmehr * [MPS] Move max_pool2d to mps dispatch key (#90772) Related issue: #77394 This PR also modifies some assertions in the codegen, an explanatory comment for it has been added. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90772 Approved by: https://github.com/albanD * [MPS] Convert output back to ChannelsLast for MaxPool2D (#94877) Since we re-stride the indices and output in MPS pooling from ChannelsLast to Contiguous, we need to convert the results back to ChannelsLast. This will fix the failure with test_memory_format with MaxPool2D in test_modules.py. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94877 Approved by: https://github.com/kulinseth, https://github.com/DenisVieriu97 --------- Co-authored-by: Denis Vieriu <104024078+DenisVieriu97@users.noreply.github.com> Co-authored-by: Li-Huai (Allan) Lin <qqaatw@gmail.com> Co-authored-by: Ramin Azarmehr <razarmehr@apple.com>
Author
Parents
Loading