GPU Mixed precision training for WMT
This also fixes NaN issues in the attention module when using float16.
Masking is now implemented using a select instead of adding a large negative. This avoids infinities And potential gradient leakage. In particular for float16 which has a narrow range.
PiperOrigin-RevId: 395646244