pytorch
e985503d - [NNC] Fix an issue with half-scalar vars coerced to float (Take 2) (#47448)

Commit
4 years ago
[NNC] Fix an issue with half-scalar vars coerced to float (Take 2) (#47448) Summary: Take 2 of this fix, I removed the repro from the issue which is a bit flaky due to parallelism. It broke on Windows but isn't specific to Windows or this fix, I think. I'll make sure all the tests pass this time (cc zou3519). Fixes an issue where fp16 scalars created by the registerizer could be referenced as floats - causing invalid conversions which would crash in the NVRTX compile. I also noticed that we were inserting patterns like float(half(float(X))) and added a pass to collapse those down inside the CudaHalfScalarRewriter. Fixes https://github.com/pytorch/pytorch/issues/47138 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47448 Reviewed By: glaringlee Differential Revision: D24765070 Pulled By: nickgg fbshipit-source-id: 5297e647534d53657bef81f4798e8aa6a93d1fbd
Author
Parents
Loading