Call MPSAllocator callbacks when allocation fails. (#94133)
Fixes #87374
@kulinseth and @albanD This makes the MPSAllocator call the MPSAllocatorCallbacks when getting a free buffer fails on the first allocation attempt. Users can register callbacks that free a few buffers, and the allocation is then retried.
The reason we need the `recursive_mutex` is that the callbacks are supposed to free memory, so they will eventually call free_buffer(), which locks the same `mutex` used for allocation. This approach is similar to what the `FreeMemoryCallback` does in the `CUDACachingAllocator`.
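To illustrate the locking concern, here is a minimal, self-contained sketch of the retry pattern; `ToyAllocator`, `try_alloc`, and the other names are made up for this example and are not the actual MPSAllocator code:

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <mutex>
#include <new>
#include <vector>

// Toy example only: shows why a recursive_mutex is needed when callbacks
// registered with the allocator may call back into free_buffer().
class ToyAllocator {
 public:
  using Callback = std::function<void()>;

  void register_callback(Callback cb) { callbacks_.push_back(std::move(cb)); }

  void* alloc(std::size_t size) {
    std::lock_guard<std::recursive_mutex> lock(mutex_);
    void* ptr = try_alloc(size);
    if (ptr == nullptr) {
      // First attempt failed: let the registered callbacks release memory ...
      for (auto& cb : callbacks_) {
        cb();  // a callback may call free_buffer(), which re-locks mutex_
      }
      // ... then retry the allocation once.
      ptr = try_alloc(size);
    }
    return ptr;
  }

  void free_buffer(void* ptr) {
    // Re-entrant: safe even when invoked from a callback running under alloc().
    std::lock_guard<std::recursive_mutex> lock(mutex_);
    ::operator delete(ptr);
  }

 private:
  void* try_alloc(std::size_t size) {
    // Stand-in for "find a free cached buffer or allocate a new one".
    if (fail_first_) {
      fail_first_ = false;
      return nullptr;  // simulate the first try running out of memory
    }
    return ::operator new(size, std::nothrow);
  }

  std::recursive_mutex mutex_;
  std::vector<Callback> callbacks_;
  bool fail_first_ = true;
};

int main() {
  ToyAllocator allocator;
  allocator.register_callback(
      [] { std::cout << "callback: freeing cached buffers\n"; });
  void* p = allocator.alloc(64);
  allocator.free_buffer(p);
}
```

With a plain `std::mutex`, the `free_buffer()` call made from inside the callback would deadlock on the lock already held by `alloc()`.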
This PR tries to be as minimal as possible, but there are some additional improvements and cleanups that could follow, such as:
- In current main there is no way the callbacks can ever be called, so we could probably rename the callback registry to something that reflects the naming used in the CUDACachingAllocator:
https://github.com/pytorch/pytorch/blob/996cc1c0d09a7bc6ad33441c08961226005c69bf/c10/cuda/CUDACachingAllocator.h#L14-L24
- Review the EventTypes here:
https://github.com/pytorch/pytorch/blob/996cc1c0d09a7bc6ad33441c08961226005c69bf/aten/src/ATen/mps/MPSAllocator.h#L18-L23
- And IMHO a nice improvement would be making the callbacks aware of AllocParams, so they can decide how aggressive to be depending on how much memory is requested. For that, I would pass AllocParams in the signature of the executeCallback call (see the sketch after this list):
https://github.com/pytorch/pytorch/blob/996cc1c0d09a7bc6ad33441c08961226005c69bf/aten/src/ATen/mps/MPSAllocator.h#L25
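As a rough illustration of that last suggestion, here is a sketch of what the callback signature could look like; `AllocParams`, `CallbackRegistry`, and `executeCallbacks` are placeholder names for this example, not the existing MPSAllocator API:

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Placeholder for the allocator's per-request parameters.
struct AllocParams {
  std::size_t requested_size;  // how much memory the failing allocation asked for
};

class CallbackRegistry {
 public:
  using Callback = std::function<void(const AllocParams&)>;

  void add(Callback cb) { callbacks_.push_back(std::move(cb)); }

  // Forward the allocation parameters so each callback can scale how much it frees.
  void executeCallbacks(const AllocParams& params) {
    for (auto& cb : callbacks_) {
      cb(params);
    }
  }

 private:
  std::vector<Callback> callbacks_;
};

int main() {
  CallbackRegistry registry;
  registry.add([](const AllocParams& p) {
    if (p.requested_size > (1u << 20)) {
      // large request: release everything we can
    } else {
      // small request: release just a few cached buffers
    }
  });
  registry.executeCallbacks(AllocParams{/*requested_size=*/4u << 20});
}
```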
Let me know if you think we could sneak those changes into this PR, or if it would be better to propose them in separate, smaller PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94133
Approved by: https://github.com/kulinseth, https://github.com/razarmehr, https://github.com/albanD