pytorch
035310c5 - Handle shared memory cases in MathBithFallback (#63602)

Commit
3 years ago
Handle shared memory cases in MathBithFallback (#63602) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63602 This PR fixes the case when a read and write is performed on a memory shared between mutable and (or) non-mutable arguments. Example: ``` a=torch.tensor([1+1j]) b=a.conj() b.add_(a) # should return tensor([2]) but returns tensor ([2-2j]) ``` The issue here is that in the conjugate fallback, we resolve the conjugation in-place for mutable arguments which can be a problem as shown above in the case when other input arguments share memory with the mutable argument(s). This PR fixes this issue by: 1. first scanning through the operator input arguments and creating a vector of mutable arguments that have the conj bit set to `True` (and accordingly setting the flag `check_for_alias_with_mut_arg ` to `True` or `False`). 2. Iterating through all the arguments. At this time we only look at the non-mutable arguments. If `check_for_alias_with_mut_arg` is set to `True`, then we iterate through `mutable_inputs` to check if the current arg tensor in question doesn't alias any of the entries in `mutable_inputs`. If yes, then we clone the non-mutable tensor arg, else we resolve the conjugation as before. 3. Now we look through the mutable_inputs vector (which contains only mutable input tensors with conj bit set to `True`). We in-place conjugate each of the entries in the vector. 4. Do the computation. 5. Re-conjugate the mutable argument tensors. NOTE: `TensorLists` are not fully handled in ConjugateFallback. Please see the in-line comment for more details. Fixes https://github.com/pytorch/pytorch/issues/59943 Test Plan: Imported from OSS Reviewed By: gmagogsfm Differential Revision: D30466905 Pulled By: anjali411 fbshipit-source-id: 58058e5e6481da04a12d03f743c1491942a6cc9b
Author
Parents
Loading