[cpu] enable bfloat16 and refactor for flash attention (#104863)
Feature RFC: https://github.com/pytorch/rfcs/pull/56.
BF16 support is added to the flash attention CPU kernel for both the forward and backward paths.
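A minimal sketch (not taken from this PR) of how the new path can be exercised from Python: scaled_dot_product_attention with bfloat16 tensors on CPU, running both forward and backward. Tensor shapes and names here are arbitrary, and backend selection is left to PyTorch's dispatcher.

```python
import torch
import torch.nn.functional as F

# Hypothetical usage sketch: bfloat16 SDPA on CPU, forward + backward.
# Shape convention: (batch, num_heads, seq_len, head_dim).
q = torch.randn(2, 4, 128, 64, dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(2, 4, 128, 64, dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(2, 4, 128, 64, dtype=torch.bfloat16, requires_grad=True)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # forward pass
out.sum().backward()                                           # backward pass
```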
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104863
Approved by: https://github.com/jgong5, https://github.com/drisspg
ghstack dependencies: #104583, #104584, #103826, #104693