pytorch
ca791b69 - [MPS] Add higher order derivatives warning to max_pool2d (#98582)

Commit
1 year ago
[MPS] Add higher order derivatives warning to max_pool2d (#98582) The higher order derivatives calculations of `max_pool2d` require indices provided, but `mps_max_pool2d` kernel doesn't calculate it. If we calculate indices during back propagations afterwards, that would be expensive and unnecessary since users can directly call `max_pool2d` with `return_indices=True`, which calculates `indices` along. This PR adds a warning for it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/98582 Approved by: https://github.com/soulitzer
Author
Committer
Parents
Loading