4fc19e1a - [optim][adam] use fastest impl whenever possible, add util (#93184)

This changes the default selection so that ONLY when the user sets neither foreach nor fused do we switch the default, cascading Adam so that we prefer fused, then foreach, then single-tensor. To clarify:

* if the user sets True for foreach _only_, it will run the foreach implementation.
* if the user sets True for fused _only_, it will run the fused implementation.
* if the user sets True for foreach AND for fused, it will run the fused implementation.

And:

* if the user sets False for foreach _only_, it will run the single-tensor implementation.
* if the user sets False for fused _only_, it will still run the single-tensor implementation.
* if the user sets False for foreach AND for fused, it will run the single-tensor implementation.

I also didn't trust myself that much with the helper function, so I ran some local asserts on `_default_to_fused_or_foreach`. The only point left to really test is the `type(p) == torch.Tensor` check, but I think the distributed tests will catch that in CI.
```
cuda_only_fp_list = [
    torch.rand((1, 2), device="cuda", dtype=torch.float32),
    torch.rand((1, 2), device="cuda", dtype=torch.float64),
    torch.rand((1, 2), device="cuda", dtype=torch.float16),
    torch.rand((1, 2), device="cuda", dtype=torch.bfloat16),
]
cuda_only_int_list = [
    torch.randint(1024, (1, 2), device="cuda", dtype=torch.int64),
]
cpu_list = [
    torch.rand((1, 2), device="cpu", dtype=torch.float32),
    torch.rand((1, 2), device="cpu", dtype=torch.float64),
    torch.rand((1, 2), device="cpu", dtype=torch.float16),
]
none_list = [None]

# differentiable should always make it return false for both
assert _default_to_fused_or_foreach([cuda_only_fp_list], True, True) == (False, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list], True, False) == (False, False)

# cpu lists should always make it return false for both
assert _default_to_fused_or_foreach([cuda_only_fp_list, cpu_list], False, True) == (False, False)
assert _default_to_fused_or_foreach([cpu_list], False, True) == (False, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cpu_list], False, False) == (False, False)
assert _default_to_fused_or_foreach([cpu_list], False, False) == (False, False)

# has_fused triggers correctly
assert _default_to_fused_or_foreach([cuda_only_fp_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list], False, False) == (False, True)

# ints always go to foreach
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list], False, True) == (False, True)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list], False, False) == (False, True)

# Nones don't error
assert _default_to_fused_or_foreach([cuda_only_fp_list, none_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list, none_list], False, True) == (False, True)
assert _default_to_fused_or_foreach([none_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([none_list], False, False) == (False, True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93184
Approved by: https://github.com/albanD
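For reference, the flag precedence described in the commit message can be sketched as a small standalone helper. This is a hypothetical illustration, not the actual torch.optim code: `choose_impl`, `can_fuse`, and `can_foreach` are invented names standing in for the user's flags and for whatever capability checks the real helper performs on the params.

```python
# Hypothetical sketch of the precedence above -- NOT the actual torch.optim
# code. `choose_impl` is an invented name for illustration only.
def choose_impl(fused=None, foreach=None, can_fuse=False, can_foreach=False):
    """Pick an implementation given user flags and (assumed) capability checks."""
    if fused is None and foreach is None:
        # Neither flag was set by the user: cascade fused -> foreach ->
        # single-tensor, subject to what the params actually support.
        if can_fuse:
            return "fused"
        if can_foreach:
            return "foreach"
        return "single-tensor"
    # At least one flag was set explicitly: a True wins (fused beats foreach),
    # and any explicit False without a True falls back to single-tensor.
    if fused:
        return "fused"
    if foreach:
        return "foreach"
    return "single-tensor"
```

Note that an explicit `fused=False` (with foreach unset) lands on single-tensor rather than re-entering the cascade, matching the second bullet list above.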