implement ordering (#91362)
# Summary
Depending on the input, flash-attention is not always the fastest fused kernel; in some cases memory-efficient attention performs better. This PR implements a simple heuristic function for deciding the order in which the kernel functions are tried. It is based on the xformers dispatch function found here: https://github.com/fairinternal/xformers/blob/15bff4986c3a4376176a4e6fa3dc0f2a120fa0bb/xformers/ops/fmha/dispatch.py#L13
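As a rough illustration of the idea (the function name, parameter struct, backend names, and threshold below are all assumptions for this sketch, not the actual PyTorch implementation), such a heuristic inspects the input shapes and returns the fused kernels in preference order, falling back to the math implementation last:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SDPParams:
    """Hypothetical container for the shape information the heuristic inspects."""
    batch_size: int
    num_heads: int
    seq_len: int
    head_dim: int


def priority_order(params: SDPParams) -> List[str]:
    """Return kernel names in the order they should be attempted.

    Illustrative rule only: prefer flash-attention by default, but try
    memory-efficient attention first when the head dimension is large,
    where flash-attention tends to lose its advantage. The threshold of
    128 is an example value, not the one used in the PR.
    """
    if params.head_dim > 128:
        return ["mem_efficient", "flash", "math"]
    return ["flash", "mem_efficient", "math"]
```

The caller would then walk this list and use the first kernel whose input constraints are satisfied.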
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91362
Approved by: https://github.com/cpuhrsch