Fix norm decomp when dtype is passed in (#89508)
Fix for https://github.com/pytorch/torchdynamo/issues/1889. The wrapper was doing a downcast even when the dtype was explicitly passed in.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89508
Approved by: https://github.com/anijain2305