Fix the scaled dot product attention doc (#120859)
Fixes #120810
The code below verifies the broadcasting behavior reported in the issue:
```
import torch
B = 3
S = 5
L = 7
E = 16
EV = 32
additional_batches = [2, 4]
query_shape = [B] + additional_batches + [L, E]
key_shape = [B] + additional_batches + [S, E]
value_shape = [B] + additional_batches + [S, EV]
query = torch.rand(*query_shape)
key = torch.rand(*key_shape)
value = torch.rand(*value_shape)
# Boolean mask: True marks positions that are allowed to attend
mask = torch.zeros((1, 1, S), dtype=torch.bool)
mask[:, :, S // 2 :] = True
# query = query.to("cuda")
# key = key.to("cuda")
# value = value.to("cuda")
# mask = mask.to("cuda")
attention = torch.nn.functional.scaled_dot_product_attention(query, key, value, mask)
print(f"query shape = {query.shape}")
print(f"key shape = {key.shape}")
print(f"value shape = {value.shape}")
print(f"mask shape = {mask.shape}")
print(f"attention shape = {attention.shape}")
# On both CPU and CUDA, the output shapes are:
# query shape = torch.Size([3, 2, 4, 7, 16])
# key shape = torch.Size([3, 2, 4, 5, 16])
# value shape = torch.Size([3, 2, 4, 5, 32])
# mask shape = torch.Size([1, 1, 5])
# attention shape = torch.Size([3, 2, 4, 7, 32])
# Check that torch.add broadcasts the mask against query @ key.mT
res = query @ key.mT
print(res.shape)
res2 = torch.add(res, mask)
print(res2.shape)
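# Expected output (the mask [1, 1, 5] broadcasts over all leading dims and L):
# torch.Size([3, 2, 4, 7, 5])
# torch.Size([3, 2, 4, 7, 5])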
```
At the code level, in the default math backend,
https://github.com/pytorch/pytorch/blob/ab38354887fe86e611f6f5bef0b9d7cf72e27d8b/aten/src/ATen/native/transformers/attention.cpp#L735
the `at::add` call broadcasts `attn_mask` against the attention scores computed by `auto attn = at::matmul(query, key.transpose(-2, -1) * scaling_factor);`.
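As a minimal sketch (shapes taken from the snippet above), the same broadcast can be checked with `torch.broadcast_shapes`:
```
import torch

# Scores from query @ key.mT have shape [B, *batch, L, S]; the boolean
# mask [1, 1, S] broadcasts over every leading batch dimension and over L.
scores_shape = (3, 2, 4, 7, 5)
mask_shape = (1, 1, 5)
print(torch.broadcast_shapes(scores_shape, mask_shape))
# torch.Size([3, 2, 4, 7, 5])
```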
- Changed the doc in [torch/nn/functional.py](https://github.com/pytorch/pytorch/pull/120859/files#diff-c358c214f663ba0c8b9c6846fbe0042fa29494cf02fe4714a17dcd0d268b035b).
- Also fixed a few inconsistencies in the C++ comments.
@mikaylagawarecki
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120859
Approved by: https://github.com/drisspg