pytorch
ef13fde2 - Increase mem eff backward performance (#101847)

Commit
1 year ago
Increase mem eff backward performance (#101847) # Summary This is another upstream which is much smaller than the previous. This bumps the kernel versions from xformers Current: [6425fd0cacb1a6579aa2f0c4a570b737cb10e9c3](https://github.com/facebookresearch/xformers/commit/6425fd0cacb1a6579aa2f0c4a570b737cb10e9c3) With this PR: [1d635e193e169fc677b2e7fa42dad7ebe88eec9e](https://github.com/facebookresearch/xformers/commit/1d635e193e169fc677b2e7fa42dad7ebe88eec9e) ### Notable Changes: - Drastically improve the BW pass in multiple cases (especially when B*numHeads < 100) - H100 Support: *Warning* While these kernels have been added, we don't have the CI/CD machines to test. - Enables a deterministic mode. ## Specific Changes - Updates to the backward kernel. - Added num_splits_key which we hard code to -1. (This is a another performance knob that we set to the heuristic) - Update gen_code and kernels to produce h100 instantiations. ### Due Diligence Checks: * CUDA_lib size: No changes in size #### Peformance * Micro Benchmark: (batch_size: 1, num_heads=25, seq_len=4096, embed_dim = 64 | grid:[1,25,1]block: [128,1,1]) * MemEfficientAttention Backward Kernel: 27.972 ms * After the updated Xformers code(https://github.com/pytorch/pytorch/pull/100583): 23.958 ms * With this PR: 4.085 ms * Ran micro benchmarks on sdpa_forw().sum().backward() over a range of dtypes, and input shapes * Geo_mean increase -> 1.17x * Max increase -> 2.95x * min_increase -> 0.8x Pull Request resolved: https://github.com/pytorch/pytorch/pull/101847 Approved by: https://github.com/cpuhrsch
Author
Committer
Parents
Loading