pytorch
ff6d2a6d - Add mem efficient backward (#88856)

Registers the derivative for the memory-efficient attention backward:

- Uses gradcheck to test correctness. The kernel is not implemented for fp64, so the checks run in fp32 with bumped tolerances.
- Also incorporates updates from the xFormers main branch and the flash-attention cutlass branch.
- This enables the fused backward to be called for scaled dot product attention.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88856
Approved by: https://github.com/cpuhrsch
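The gradcheck approach described above can be sketched as follows. This is a minimal illustration, not the commit's actual test code; it assumes the public `F.scaled_dot_product_attention` API and the `torch.backends.cuda.sdp_kernel` context manager for restricting dispatch to the memory-efficient kernel, and the tolerance values are placeholders.

```python
# Minimal sketch (not the commit's actual test): gradcheck the memory-efficient
# attention backward in fp32 with bumped tolerances, since the kernel has no
# fp64 implementation. Assumes F.scaled_dot_product_attention and the
# torch.backends.cuda.sdp_kernel context manager are available.
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

device = "cuda"
batch, heads, seq_len, head_dim = 2, 2, 8, 16

def make_input():
    # fp32 inputs; gradcheck normally expects fp64, so tolerances are relaxed below.
    return torch.rand(batch, heads, seq_len, head_dim,
                      device=device, dtype=torch.float32, requires_grad=True)

q, k, v = make_input(), make_input(), make_input()

def mem_efficient_sdpa(q, k, v):
    # Restrict dispatch to the memory-efficient kernel so its backward is exercised.
    with torch.backends.cuda.sdp_kernel(enable_flash=False,
                                        enable_math=False,
                                        enable_mem_efficient=True):
        return F.scaled_dot_product_attention(q, k, v)

# Bumped tolerances (illustrative values) compensate for fp32 error in the
# finite-difference comparison against the registered analytical backward.
assert gradcheck(mem_efficient_sdpa, (q, k, v), atol=1e-3, rtol=1e-3, nondet_tol=1e-4)
```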