Numerical stability of embedding kernels (#22401)
Summary:
Address the issue raised in https://github.com/pytorch/pytorch/issues/22377.
The PR https://github.com/pytorch/pytorch/issues/22016 introduced a temporary weight tensor `grad_weight_per_segment` with the same dtype as the end result, which loses precision when the partial sums are accumulated in `float16`.
In this PR, a `float32` temporary tensor is used instead when the input is `float16`.
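As a minimal Python sketch of why the accumulation dtype matters (illustrative only, not the actual kernel code; the tensor sizes and values are made up): summing many small `float16` addends stalls once the running sum's rounding step exceeds each addend, while accumulating in `float32` and casting back at the end stays accurate.

```python
import torch

# Many small per-segment partial sums, as in the embedding backward pass.
grads = torch.full((10000,), 0.01, dtype=torch.float16)

# Naive float16 accumulation: once the running sum reaches ~32, each
# 0.01 addend falls below half a float16 ulp and rounds away to nothing.
acc_fp16 = torch.zeros((), dtype=torch.float16)
for g in grads:
    acc_fp16 += g

# Accumulate in a float32 temporary (analogous to the float32
# grad_weight_per_segment buffer), then cast back at the end.
acc_fp32 = torch.zeros((), dtype=torch.float32)
for g in grads:
    acc_fp32 += g.float()

print(acc_fp16)         # stalls far below the true sum of ~100
print(acc_fp32.half())  # close to 100
```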
ngimel, can I get you to review? I think I have fixed the issues you pointed out.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22401
Differential Revision: D16077319
Pulled By: mrshenli
fbshipit-source-id: 7cfad7f40b4d41a244052baa2982ab51bbbd7309