[decomp] Add tests for different dtypes to SDPA decomposition (#119239)
Summary: As titled. Skipping torch.bfloat16 because for some reason the
difference is 0.01.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119239
Approved by: https://github.com/drisspg