DeepSpeed
d7d4eebc - Fix #7733: Replace torch.sqrt with math.sqrt in scale_lr for sqrt method (#7735)

Commit
54 days ago
Fix #7733: Replace torch.sqrt with math.sqrt in scale_lr for sqrt method (#7735) Fixes #7733 When using lr_scaling_method='sqrt' with dynamic batching, the scale_lr function was failing with TypeError because torch.sqrt expects a Tensor but receives a Python float from batch_size/base_batch_size division. Changed torch.sqrt to math.sqrt which correctly handles Python floats. This fixes the issue where training would fail with: TypeError: sqrt(): argument 'input' (position 1) must be Tensor, not float --------- Signed-off-by: Rakshit-gen <sisodiarakshit456@gmail.com> Co-authored-by: Xinyu Lian <lian7@illinois.edu> Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Author
Parents
Loading