pytorch
a7cf04ec - Workaround for MAGMA accessing illegal memory in batched cholesky (#50957)

Commit · 4 years ago
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50957

MAGMA has an off-by-one error in its batched Cholesky implementation which causes illegal memory access for certain inputs. The workaround implemented in this PR is to pad the input to MAGMA with 1 extra element.

**Benchmark**

Ran the script below both before and after this PR and got similar results.

*Script*
```
import torch
from torch.utils import benchmark

DTYPE = torch.float32
BATCHSIZE = 512 * 512
MATRIXSIZE = 16

a = torch.eye(MATRIXSIZE, device='cuda', dtype=DTYPE)

t0 = benchmark.Timer(
    stmt='torch.cholesky(a)',
    globals={'a': a},
    label='Single'
)

t1 = benchmark.Timer(
    stmt='torch.cholesky(a)',
    globals={'a': a.expand(BATCHSIZE, -1, -1)},
    label='Batched'
)

print(t0.timeit(100))
print(t1.timeit(100))
```

*Results before*
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7faf9bc63400>
Single
  2.08 ms
  1 measurement, 100 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7faf9bc63400>
Batched
  7.68 ms
  1 measurement, 100 runs, 1 thread
```

*Results after*
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7faf9bc63400>
Single
  2.10 ms
  1 measurement, 100 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7faf9bc63400>
Batched
  7.56 ms
  1 measurement, 100 runs, 1 thread
```

Fixes https://github.com/pytorch/pytorch/issues/41394, https://github.com/pytorch/pytorch/issues/26996, https://github.com/pytorch/pytorch/issues/48996

See also https://github.com/pytorch/pytorch/issues/42666, https://github.com/pytorch/pytorch/pull/26789

TODO
---
- [x] Benchmark to check for perf regressions

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D26050978

Pulled By: heitorschueroff

fbshipit-source-id: 7a5ba7e34c9d74b58568b2a0c631cc6d7ba63f86
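The padding idea above can be illustrated with a small, self-contained sketch. This is not the actual PR code (the real fix lives in PyTorch's C++/CUDA MAGMA glue); it only demonstrates the general technique of over-allocating a workspace by one element so that an off-by-one overread by a library kernel stays inside the allocation. The function name `padded_workspace` and the use of NumPy are purely illustrative assumptions.

```python
import numpy as np

def padded_workspace(batch, n, dtype=np.float32):
    """Allocate a buffer for `batch` n-by-n matrices plus 1 pad element.

    If a library kernel reads one element past the end of the last
    matrix (an off-by-one bug), the extra pad element keeps that read
    inside the allocation, avoiding an illegal memory access.
    """
    needed = batch * n * n
    buf = np.zeros(needed + 1, dtype=dtype)
    # The caller copies inputs into buf[:needed] and passes the buffer
    # to the library; buf[needed] is the sacrificial pad element.
    return buf, needed

# Example: workspace for 4 batched 3x3 matrices.
buf, needed = padded_workspace(4, 3)
print(buf.size, needed)  # buffer holds one element more than needed
```

The trade-off is one extra copy into the padded buffer, which is why the PR includes a benchmark to confirm no measurable regression.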