make flash_attn_bw impl correct w.r.t. meta when k and v have different strides (#119500)
`dv = at::empty_like(k)` and `dv = at::empty_like(v)` can be materially different, because `empty_like` tries to preserve the strides of the input when possible. So if `k` is contiguous, but `v`, is transposed, then before this PR, `dv` would be computed to be contiguous.
Alternatively, we could change the meta implementation of `aten._scaled_dot_product_flash_attention` to this:
```
grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2)
grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2)
grad_v = torch.empty_like(key.transpose(1, 2)).transpose(1, 2)
return grad_q, grad_k, grad_v
```
But (I think?) the logic in the sdpa backward impl was a typo.
I noticed this because changing the meta formula as above was enough to fix the issue with the `aot_eager` backend in this [link](https://github.com/pytorch/pytorch/issues/116935#issuecomment-1914310523).
A minimal repro that I made looks like this:
```
import torch
# in this repro, "grad_out" and "value" are transposed tensors,
# but "key" and "value" are contiguous
a = torch.randn(2, 513, 16, 64, dtype=torch.float16, device='cuda').transpose(1, 2)
b = torch.randn(2, 16, 513, 64, dtype=torch.float16, device='cuda')
c = torch.randn(2, 16, 513, 64, dtype=torch.float16, device='cuda')
d = torch.randn(2, 513, 16, 64, dtype=torch.float16, device='cuda').transpose(1, 2)
e = torch.randn(2, 16, 513, 64, dtype=torch.float16, device='cuda')
f = torch.randn(2, 16, 513, device='cuda')
g = None
h = None
i = 513
j = 513
k = 0.0
l = False
m = torch.tensor(1, dtype=torch.int64)
n = torch.tensor(1, dtype=torch.int64)
out1_ref, out2_ref, out3_ref = torch.ops.aten._scaled_dot_product_flash_attention_backward(a, b, c, d, e, f, g, h, i, j, k, l, m, n, scale=0.125)
from torch._meta_registrations import meta__scaled_dot_product_flash_backward
out1_test, out2_test, out3_test = meta__scaled_dot_product_flash_backward(a, b, c, d, e, f, g, h, i, j, k, l, m, n, scale=0.125)
# prints True True
print(out1_ref.is_contiguous())
print(out1_test.is_contiguous())
# prints True True
print(out2_ref.is_contiguous())
print(out2_test.is_contiguous())
# prints True False
print(out3_ref.is_contiguous())
print(out3_test.is_contiguous())
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119500
Approved by: https://github.com/drisspg, https://github.com/ezyang, https://github.com/Skylion007