Fix #7733: Replace torch.sqrt with math.sqrt in scale_lr for sqrt method (#7735)
Fixes #7733
When using lr_scaling_method='sqrt' with dynamic batching, the scale_lr
function was failing with TypeError because torch.sqrt expects a Tensor
but receives a Python float from batch_size/base_batch_size division.
Changed torch.sqrt to math.sqrt which correctly handles Python floats.
This fixes the issue where training would fail with: TypeError: sqrt():
argument 'input' (position 1) must be Tensor, not float
---------
Signed-off-by: Rakshit-gen <sisodiarakshit456@gmail.com>
Co-authored-by: Xinyu Lian <lian7@illinois.edu>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>