Enable TF32 support for cuBLAS (#40800)
Summary:
Benchmark on a fully connected network and torchvision models (time in seconds) on GA100:
| model | batch size | forward(TF32) | forward(FP32) | backward(TF32) | backward(FP32) |
|--------------------|------------|---------------|---------------|----------------|----------------|
| FC 512-128-32-8 | 512 | 0.000211 | 0.000321 | 0.000499 | 0.000532 |
| alexnet | 512 | 0.0184 | 0.0255 | 0.0486 | 0.0709 |
| densenet161 | 128 | 0.0665 | 0.204 | 0.108 | 0.437 |
| googlenet | 256 | 0.0925 | 0.110 | 0.269 | 0.326 |
| inception_v3 | 256 | 0.155 | 0.214 | 0.391 | 0.510 |
| mnasnet1_0 | 512 | 0.108 | 0.137 | 0.298 | 0.312 |
| mobilenet_v2 | 512 | 0.114 | 0.294 | 0.133 | 0.303 |
| resnet18 | 512 | 0.0722 | 0.100 | 0.182 | 0.228 |
| resnext50_32x4d | 256 | 0.170 | 0.237 | 0.373 | 0.479 |
| shufflenet_v2_x1_0 | 512 | 0.0463 | 0.0473 | 0.125 | 0.123 |
| squeezenet1_0 | 512 | 0.0870 | 0.0948 | 0.205 | 0.214 |
| vgg16 | 256 | 0.167 | 0.234 | 0.401 | 0.502 |
| wide_resnet50_2 | 512 | 0.186 | 0.310 | 0.415 | 0.638 |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40800
Reviewed By: mruberry
Differential Revision: D22517785
Pulled By: ngimel
fbshipit-source-id: 87334c8935616f72a6af5abbd3ae69f76923dc3e