Fallback to eager for float8 ops in inductor (#108293)
# Summary
As a stopgap until inductor supports the FP8 dtypes natively, we would like to fall back to eager. Currently three ops are needed for this:
- `_scaled_mm` (matmul for fp8 types)
- `clone` (for creating new copies of fp8 tensors)
- `to` (for converting to and from fp8 types)
This PR registers a fallback for `_scaled_mm` and adds the fp8 dtypes to the set that triggers `unsupported_input_tensor`.
Prior to these changes, compiling code using these ops failed with:
```shell
File "/home/drisspg/meta/pytorch/torch/_inductor/codegen/triton_utils.py", line 11, in signature_of
tye = JITFunction._type_of(arg.dtype)
File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/triton/runtime/jit.py", line 229, in _type_of
return key if isinstance(key, str) else f"*{tys[dtype_str]}"
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
KeyError: 'float8_e4m3fn'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108293
Approved by: https://github.com/peterbell10