Workaround for MAGMA accessing illegal memory in batched cholesky (#50957)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50957
MAGMA has an off-by-one error in its batched Cholesky implementation that causes illegal memory accesses for certain inputs. The workaround implemented in this PR is to pad the input buffer passed to MAGMA with one extra element, so the out-of-bounds read stays inside memory we own.
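The actual fix lives in the C++ MAGMA bindings, but the idea can be sketched in Python. This is an illustrative sketch only (the function name and the use of `torch.linalg.cholesky` are my own, not part of this PR): allocate one extra trailing element, copy the batch into the front of that allocation, and run the factorization on a view of the padded buffer.

```python
import torch

def cholesky_with_padding(a):
    """Illustrative sketch of the padding workaround (not the PR's C++ code):
    back the input with an allocation one element larger than needed, so a
    one-element out-of-bounds read by the kernel lands in owned memory."""
    n = a.numel()
    # One extra element of slack at the end of the allocation.
    padded = torch.empty(n + 1, dtype=a.dtype, device=a.device)
    padded[:n].copy_(a.reshape(-1))
    # Same data and shape as `a`, but viewing into the padded buffer.
    view = padded[:n].reshape(a.shape)
    return torch.linalg.cholesky(view)
```

The extra element costs a single scalar of memory per call, which is why the benchmark below shows no measurable regression.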
**Benchmark**
Ran the script below both before and after this PR and got similar results, so the padding does not introduce a performance regression.
*Script*
```
import torch
from torch.utils import benchmark

DTYPE = torch.float32
BATCHSIZE = 512 * 512
MATRIXSIZE = 16

a = torch.eye(MATRIXSIZE, device='cuda', dtype=DTYPE)

t0 = benchmark.Timer(
    stmt='torch.cholesky(a)',
    globals={'a': a},
    label='Single',
)
t1 = benchmark.Timer(
    stmt='torch.cholesky(a)',
    globals={'a': a.expand(BATCHSIZE, -1, -1)},
    label='Batched',
)

print(t0.timeit(100))
print(t1.timeit(100))
```
*Results before*
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7faf9bc63400>
Single
2.08 ms
1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7faf9bc63400>
Batched
7.68 ms
1 measurement, 100 runs , 1 thread
```
*Results after*
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7faf9bc63400>
Single
2.10 ms
1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7faf9bc63400>
Batched
7.56 ms
1 measurement, 100 runs , 1 thread
```
Fixes https://github.com/pytorch/pytorch/issues/41394, https://github.com/pytorch/pytorch/issues/26996, https://github.com/pytorch/pytorch/issues/48996
See also https://github.com/pytorch/pytorch/issues/42666, https://github.com/pytorch/pytorch/pull/26789
TODO
---
- [x] Benchmark to check for perf regressions
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D26050978
Pulled By: heitorschueroff
fbshipit-source-id: 7a5ba7e34c9d74b58568b2a0c631cc6d7ba63f86