onnxruntime
9f68a27c - [ORTModule] Handle Cast on Constant Number on Triton Code-gen (#19321)

Commit
2 years ago
When scaled_dot_product_attention is used with the float16 type, the exported graph contains Sqrt(float16(constant)). ORT cannot constant-fold this node because the Sqrt CPU kernel does not support float16. As a result, the Triton code-gen emits code such as:

    result = 128.0.to(tl.float32)

This code does not compile, because .to() cannot be applied to a constant number. This PR handles the case by skipping the Cast when its input is a constant number.
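The idea of skipping the Cast for a constant operand can be sketched as follows. This is a minimal illustration only, not the actual ORTModule Triton code-gen: the helper names `is_number_literal` and `gen_cast`, and the use of a regular expression to detect numeric literals, are assumptions made for this example.

```python
import re

# Hypothetical helpers for illustration; not the real ORTModule code-gen API.
# Matches plain numeric literals such as 128, 128.0, -0.5, .5 or 1e-5.
_NUMBER_RE = re.compile(r"[-+]?(\d+\.?\d*|\.\d+)([eE][-+]?\d+)?")


def is_number_literal(expr: str) -> bool:
    """Return True if the expression text is a bare numeric constant."""
    return _NUMBER_RE.fullmatch(expr.strip()) is not None


def gen_cast(output: str, operand: str, triton_dtype: str) -> str:
    """Emit one line of Triton source for a Cast node.

    A numeric constant is assigned directly, because calling .to() on a
    number literal does not compile. Any other operand is assumed to be a
    tensor expression and gets the usual .to(dtype) call.
    """
    if is_number_literal(operand):
        return f"{output} = {operand}"
    return f"{output} = {operand}.to({triton_dtype})"


if __name__ == "__main__":
    print(gen_cast("result", "128.0", "tl.float32"))
    print(gen_cast("result", "t0", "tl.float32"))
```

With this sketch, a constant operand such as `128.0` produces a plain assignment, while a tensor variable such as `t0` still receives the `.to(tl.float32)` call.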