pytorch
bbc3cc67 - [CUDA graphs] [BC-breaking] Makes torch.cuda.amp.GradScaler scale updates in-place for better composability with graph capture (#55562)

Commit
3 years ago
[CUDA graphs] [BC-breaking] Makes torch.cuda.amp.GradScaler scale updates in-place for better composability with graph capture (#55562) Summary: I'd like the following pattern (a natural composition of Amp with full fwd+bwd capture) to work: ```python # Create "static_input" with dummy data, run warmup iterations, # call optimizer.zero_grad(set_to_none=True), then g = torch.cuda._Graph() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): optimizer.zero_grad(set_to_none=True) g.capture_begin() with autocast(): out = model(static_input) loss = loss_fn(out) scaler.scale(loss).backward() g.capture_end() torch.cuda.current_stream().wait_stream(s) # Training loop: for b in data: # optimizer.zero_grad() deliberately omitted, replay()'s baked-in backward will refill statically held .grads static_input.copy_(b) g.replay() scaler.step(optimizer) scaler.update() ``` Right now `GradScaler` can't work with this pattern because `update()` creates the scale tensor for the next iteration out of place. This PR changes `update()` to act in place on a long-lived scale tensor that stays static across iterations. I'm not sure how this change affects XLA (see https://github.com/pytorch/pytorch/pull/48570), so we shouldn't merge without approval from ailzhang yaochengji. Tagged bc-breaking because it's a change to the amp update utility function in native_functions.yaml. The function was never meant to be user-facing though. Pull Request resolved: https://github.com/pytorch/pytorch/pull/55562 Reviewed By: zou3519 Differential Revision: D28046159 Pulled By: ngimel fbshipit-source-id: 02018c221609974546c562f691e20ab6ac611910
Author
Parents
Loading