11aab72d - [SDPA] Add an optional scale kwarg (#95259)

# Summary

This PR adds an optional kwarg to `torch.nn.functional.scaled_dot_product_attention()`. The new kwarg is a scaling factor that is applied after the `q @ k.T` step of the computation. The efficient kernel was updated to support it; the flash and math backends were minimally updated as well. This will reduce the complexity of #94729 and has been asked for by several users.

# Review Highlights

- As far as I know I did this the correct way, and it is both BC and FC compliant. However, I always seem to break internal workloads, so I would appreciate advice on whether I did this right.
- I named the optional arg `scale`. It should perhaps be named `scale_factor`; I will make that change if reviewers think we should rename it.
- `scale` is interpreted as `Q @ K.T * scale`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95259
Approved by: https://github.com/cpuhrsch
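To make the semantics concrete, here is a minimal, dependency-free sketch of what the math backend computes with the new kwarg: attention scores are `Q @ K.T * scale` (defaulting to `1/sqrt(head_dim)` when `scale` is `None`), followed by a row-wise softmax and a weighted sum over `V`. The function `sdpa_reference` is a hypothetical illustration, not the actual kernel code from this PR.

```python
import math

def sdpa_reference(q, k, v, scale=None):
    """Reference scaled dot-product attention on plain nested lists.

    `scale` mirrors the new kwarg: when None, the conventional
    1/sqrt(head_dim) is used; otherwise the scores are Q @ K.T * scale.
    """
    d = len(q[0])  # head dimension
    s = 1.0 / math.sqrt(d) if scale is None else scale
    # scores[i][j] = dot(q_i, k_j) * s  -- scale applied after q @ k.T
    scores = [[s * sum(qi[t] * kj[t] for t in range(d)) for kj in k]
              for qi in q]
    # numerically stable row-wise softmax
    attn = []
    for row in scores:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        z = sum(exps)
        attn.append([e / z for e in exps])
    # output[i] = sum_j attn[i][j] * v_j
    dv = len(v[0])
    return [[sum(a[j] * v[j][t] for j in range(len(v))) for t in range(dv)]
            for a in attn]
```

Passing `scale=1.0 / math.sqrt(head_dim)` explicitly reproduces the default behavior, which is why the change is backward compatible: omitting the kwarg leaves existing call sites unchanged.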