Fixes NaN with large bf16 values (#122135)
Fixes #121558
Performance on main:
```Markdown
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal | dtype | forward_time | backward_time |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| 1 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 12.608132004970683 | 65.90210803551601 |
| 1 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.75877740024589 | 64.83824399765581 |
| 1 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 16.465420153690506 | 67.6770955324173 |
| 1 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 17.398148600477725 | 68.19829455344006 |
| 1 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 29.053532000398263 | 99.58901099162175 |
| 1 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 27.826815698063 | 98.05690299253911 |
| 1 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 49.89655229728669 | 178.24282555375248 |
| 1 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 48.840098950313404 | 174.5950729819015 |
| 1 | 16 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 505.66218036692584 | 1865.9265094902366 |
| 1 | 16 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 295.0534054543823 | 967.3831606050952 |
| 1 | 32 | 128 | 128 | 2048 | True | torch.bfloat16 | 11.496030446141958 | 55.11070846114308 |
| 1 | 32 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.47399884648621 | 55.452342028729625 |
| 1 | 32 | 256 | 256 | 2048 | True | torch.bfloat16 | 13.216444296995178 | 55.14447903260589 |
| 1 | 32 | 256 | 256 | 2048 | False | torch.bfloat16 | 12.763233599252999 | 55.142355500720434 |
| 1 | 32 | 512 | 512 | 2048 | True | torch.bfloat16 | 19.409965351223946 | 74.9107634765096 |
| 1 | 32 | 512 | 512 | 2048 | False | torch.bfloat16 | 19.02470579952933 | 74.84168506925926 |
| 1 | 32 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 46.37695319834165 | 172.19150450546294 |
| 1 | 32 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 45.225963747361675 | 185.19691249821335 |
| 1 | 32 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 634.3090848531574 | 2249.057865119539 |
| 1 | 32 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 320.47313248040155 | 1053.0515247955916 |
| 4 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 13.448987301671878 | 63.63581650657579 |
| 4 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 12.509283400140703 | 63.059300999157124 |
| 4 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 19.71098779467866 | 105.55780201684684 |
| 4 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 18.264925852417946 | 105.12311349157244 |
| 4 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 45.218703348655254 | 222.87272597895935 |
| 4 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 43.55393464793451 | 230.63290398567915 |
| 4 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 134.02968645095825 | 514.6893998607993 |
| 4 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 157.13709802366793 | 624.5892751030624 |
| 4 | 16 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 1776.7079547047617 | 6353.551096981391 |
| 4 | 16 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 1143.6000745743513 | 3811.8767354171723 |
| 4 | 32 | 128 | 128 | 2048 | True | torch.bfloat16 | 11.717129248427227 | 55.35991647047922 |
| 4 | 32 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.746983398916198 | 55.76716404175386 |
| 4 | 32 | 256 | 256 | 2048 | True | torch.bfloat16 | 17.255573300644752 | 106.47456656442955 |
| 4 | 32 | 256 | 256 | 2048 | False | torch.bfloat16 | 16.46409669774584 | 108.07770595420152 |
| 4 | 32 | 512 | 512 | 2048 | True | torch.bfloat16 | 46.63354124641045 | 213.74862996162847 |
| 4 | 32 | 512 | 512 | 2048 | False | torch.bfloat16 | 47.01801469782367 | 240.78139301855117 |
| 4 | 32 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 127.76448752265424 | 508.08745552785695 |
| 4 | 32 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 168.6308984644711 | 667.2996102133766 |
| 4 | 32 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 2268.1598202325404 | 7727.2648515645415 |
| 4 | 32 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 1242.8469699807465 | 4161.965740495361 |
| 8 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 14.340955897932872 | 93.72280450770633 |
| 8 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 13.25262250029482 | 93.2030284893699 |
| 8 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 27.598425600444898 | 183.23776399483904 |
| 8 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 26.362583553418514 | 183.51862096460536 |
| 8 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 84.52303148806094 | 383.50319798337296 |
| 8 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 89.41743348259479 | 432.5502900755964 |
| 8 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 217.76640450116247 | 943.9354750793427 |
| 8 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 303.0781910638325 | 1225.4394043702632 |
| 8 | 16 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 3470.8542854059488 | 12194.579601055011 |
| 8 | 16 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 2268.1174043100327 | 7608.0941944383085 |
| 8 | 32 | 128 | 128 | 2048 | True | torch.bfloat16 | 12.289720651460811 | 95.88620596332476 |
| 8 | 32 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.618648946750909 | 95.56685149436818 |
| 8 | 32 | 256 | 256 | 2048 | True | torch.bfloat16 | 31.567946751601994 | 180.62468653079122 |
| 8 | 32 | 256 | 256 | 2048 | False | torch.bfloat16 | 28.611703700153157 | 189.4215695792809 |
| 8 | 32 | 512 | 512 | 2048 | True | torch.bfloat16 | 84.11306998459621 | 385.25596749968827 |
| 8 | 32 | 512 | 512 | 2048 | False | torch.bfloat16 | 93.82540901424363 | 455.77428903197875 |
| 8 | 32 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 226.80530551588163 | 965.8026450779289 |
| 8 | 32 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 327.4116570246406 | 1312.5067745568228 |
| 8 | 32 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 4445.5064804060385 | 15020.768146496266 |
| 8 | 32 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 2433.0302356975153 | 8300.016750581563 |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
```
Performance on this branch:
```Markdown
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal | dtype | forward_time | backward_time |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| 1 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 12.783618393586949 | 65.59692794689909 |
| 1 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 12.064015300711617 | 56.99719698168337 |
| 1 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 16.629025398287922 | 68.65267595276237 |
| 1 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 17.462356004398313 | 68.35797848179936 |
| 1 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 29.5476081490051 | 101.22994752600789 |
| 1 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 28.395320149138573 | 98.62275794148445 |
| 1 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 50.50016101449728 | 181.4357690163888 |
| 1 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 49.450615647947416 | 175.86063902126625 |
| 1 | 16 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 506.06461532879626 | 1866.0613044630736 |
| 1 | 16 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 299.9336270149797 | 976.4662646921353 |
| 1 | 32 | 128 | 128 | 2048 | True | torch.bfloat16 | 11.45752210286446 | 58.79682704107836 |
| 1 | 32 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.407129396684468 | 58.14061599085107 |
| 1 | 32 | 256 | 256 | 2048 | True | torch.bfloat16 | 13.822759891627355 | 56.56979401828722 |
| 1 | 32 | 256 | 256 | 2048 | False | torch.bfloat16 | 13.39154909946956 | 56.7130644340068 |
| 1 | 32 | 512 | 512 | 2048 | True | torch.bfloat16 | 20.282494352431968 | 77.29688903782517 |
| 1 | 32 | 512 | 512 | 2048 | False | torch.bfloat16 | 19.899454596452415 | 75.4446149803698 |
| 1 | 32 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 48.494275606935844 | 177.5322465109639 |
| 1 | 32 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 46.84524350450374 | 189.1778860008344 |
| 1 | 32 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 635.1026654010639 | 2248.0451600858937 |
| 1 | 32 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 335.1591735263355 | 1080.4320796160027 |
| 4 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 13.63953539985232 | 65.50709309522063 |
| 4 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 12.858113402035087 | 63.021871959790595 |
| 4 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 19.98318645055406 | 105.87883047992364 |
| 4 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 18.619045056402683 | 104.90188701078296 |
| 4 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 45.91175540117546 | 226.00732848513871 |
| 4 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 44.39614630537107 | 232.39317198749632 |
| 4 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 135.5409600073472 | 522.7949097752571 |
| 4 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 158.79383607534692 | 628.5856699105352 |
| 4 | 16 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 1775.9978299727663 | 6343.203847063706 |
| 4 | 16 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 1160.680354805663 | 3842.235009651631 |
| 4 | 32 | 128 | 128 | 2048 | True | torch.bfloat16 | 11.553713708417488 | 65.50691701704638 |
| 4 | 32 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.486379051348194 | 56.9980075233616 |
| 4 | 32 | 256 | 256 | 2048 | True | torch.bfloat16 | 17.56585600087419 | 107.89892700267956 |
| 4 | 32 | 256 | 256 | 2048 | False | torch.bfloat16 | 16.828144202008843 | 109.05519902007653 |
| 4 | 32 | 512 | 512 | 2048 | True | torch.bfloat16 | 48.23235589428805 | 217.8974545095116 |
| 4 | 32 | 512 | 512 | 2048 | False | torch.bfloat16 | 49.09284680034033 | 244.73925953498107 |
| 4 | 32 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 134.77827049791813 | 522.7259948151186 |
| 4 | 32 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 176.60772847011688 | 681.5171707421541 |
| 4 | 32 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 2267.821540008299 | 7720.425300067291 |
| 4 | 32 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 1295.3941145678982 | 4272.425139788538 |
| 8 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 14.514714101096615 | 94.2192979855463 |
| 8 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 13.553097198018804 | 93.244242540095 |
| 8 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 27.95821905019693 | 185.0469880155288 |
| 8 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 26.709681446664035 | 184.22623950755226 |
| 8 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 85.85420495364815 | 388.3417735341937 |
| 8 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 89.97473795898259 | 434.4228169647977 |
| 8 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 220.6919804448262 | 958.9654899900779 |
| 8 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 306.55586952343583 | 1233.2170095760375 |
| 8 | 16 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 3470.7326447824016 | 12183.611298678443 |
| 8 | 16 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 2299.064100370742 | 7669.618452200666 |
| 8 | 32 | 128 | 128 | 2048 | True | torch.bfloat16 | 12.427107692928985 | 96.96270158747211 |
| 8 | 32 | 128 | 128 | 2048 | False | torch.bfloat16 | 11.856995843118057 | 96.38117247959599 |
| 8 | 32 | 256 | 256 | 2048 | True | torch.bfloat16 | 32.9956392000895 | 182.52741603646427 |
| 8 | 32 | 256 | 256 | 2048 | False | torch.bfloat16 | 29.397601098753512 | 191.0755339777097 |
| 8 | 32 | 512 | 512 | 2048 | True | torch.bfloat16 | 89.06024845782667 | 392.2585004474967 |
| 8 | 32 | 512 | 512 | 2048 | False | torch.bfloat16 | 97.78487798757851 | 462.07307645818213 |
| 8 | 32 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 240.521906001959 | 992.4693452194335 |
| 8 | 32 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 341.98952303268015 | 1339.2950996058062 |
| 8 | 32 | 4096 | 2048 | 2048 | True | torch.bfloat16 | 4445.311005110853 | 15001.030603889374 |
| 8 | 32 | 4096 | 2048 | 2048 | False | torch.bfloat16 | 2535.9767401823774 | 8528.990152990447 |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
```
Aggregate timings for the two branches (`nan_fix` is this branch):
```
{'avg_forward_time_nan_fix': 399.7900972732653,
'avg_backward_time_nan_fix': 1409.652114014413,
'avg_forward_time_main_branch': 394.6807206988645,
'avg_backward_time_main_branch': 1399.4055472857629,
'geo_mean_nan_fix': 150.95049601244946,
'geo_mean_main_branch': 148.3381648508822}
```
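To make the size of the regression easier to read, the aggregate numbers can be turned into percent changes relative to main. This is a minimal sketch, not part of the PR: the helper `relative_change_pct` and the `overhead` dict are hypothetical names introduced here, and the input values are copied verbatim from the summary dict.

```python
# Aggregate values copied from the summary dict in this PR description.
summary = {
    "avg_forward_time_nan_fix": 399.7900972732653,
    "avg_backward_time_nan_fix": 1409.652114014413,
    "avg_forward_time_main_branch": 394.6807206988645,
    "avg_backward_time_main_branch": 1399.4055472857629,
    "geo_mean_nan_fix": 150.95049601244946,
    "geo_mean_main_branch": 148.3381648508822,
}


def relative_change_pct(new: float, base: float) -> float:
    """Percent change of `new` relative to `base` (positive means slower)."""
    return (new - base) / base * 100.0


# Hypothetical helper output: percent overhead of this branch vs main.
overhead = {
    metric: relative_change_pct(
        summary[f"{metric}_nan_fix"], summary[f"{metric}_main_branch"]
    )
    for metric in ("avg_forward_time", "avg_backward_time", "geo_mean")
}

for metric, pct in overhead.items():
    print(f"{metric}: {pct:+.2f}% vs main")
```

On these numbers the branch is roughly 1.3% slower in average forward time, 0.7% slower in average backward time, and 1.8% slower by geometric mean.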
The y-axis label in the plot below is wrong (the values are in microseconds), but the relative comparison still holds.
<img width="790" alt="Screenshot 2024-03-18 at 3 34 15 PM" src="https://github.com/pytorch/pytorch/assets/32754868/ca278c15-b815-4535-bdcd-07e522055466">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122135
Approved by: https://github.com/cpuhrsch