DeepSpeed
6d0dbf86 - move is_checkpointable call reducing torch.compile Graph breaks (#5759)

Commit
1 year ago
move is_checkpointable call reducing torch.compile Graph breaks (#5759)

We encountered a performance issue when running torch.compile on a model that uses the pipeline engine (Mixtral). The cause was found to be the is_checkpointable function, which is called in the engine's forward function. This function creates a graph break under torch.compile, which decreases performance, particularly since it happens in every forward call.

We propose changing how is_checkpointable is checked: its value is precomputed and stored before the forward call, and the forward function reads the stored values. With this change the graph break in the forward call is avoided, which should lead to better performance with torch.compile.

Co-authored-by: Heyang Qin <heyangqin@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
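The precompute-and-store idea described in the message can be sketched in plain Python. This is an illustration only, not the code from the commit: `Layer`, `TinyPipeline`, `interval`, `is_checkpointable_results`, and `check_calls` are hypothetical names, and the check itself (whether any layer in a chunk owns parameters) is an assumption standing in for whatever the real function inspects. No torch import is used, so the sketch shows only the caching pattern, not the graph break itself.

```python
from typing import List, Sequence


class Layer:
    """Hypothetical stand-in for a pipeline layer.

    `num_params` mimics how many parameters the layer owns.
    """

    def __init__(self, num_params: int) -> None:
        self.num_params = num_params

    def __call__(self, x: float) -> float:
        return x + self.num_params


class TinyPipeline:
    """Hypothetical pipeline that runs its layers in chunks of `interval`."""

    def __init__(self, layers: Sequence[Layer], interval: int) -> None:
        self.layers: List[Layer] = list(layers)
        self.interval = interval
        self.check_calls = 0  # counts how often the introspective check runs
        self.checkpointed_chunks = 0  # chunks that would have been checkpointed
        # Precompute the check once per chunk, before any forward call.
        self.is_checkpointable_results: List[bool] = [
            self._is_checkpointable(self.layers[start:start + interval])
            for start in range(0, len(self.layers), interval)
        ]

    def _is_checkpointable(self, funcs: Sequence[Layer]) -> bool:
        # Assumed rule for this sketch: a chunk is worth checkpointing
        # if at least one of its layers owns parameters.
        self.check_calls += 1
        return any(f.num_params > 0 for f in funcs)

    def forward(self, x: float) -> float:
        starts = range(0, len(self.layers), self.interval)
        for idx, start in enumerate(starts):
            # Read the stored flag instead of re-running the check here.
            if self.is_checkpointable_results[idx]:
                # A real engine would route this chunk through activation
                # checkpointing; the sketch only counts it.
                self.checkpointed_chunks += 1
            for f in self.layers[start:start + self.interval]:
                x = f(x)
        return x


pipe = TinyPipeline([Layer(0), Layer(0), Layer(2), Layer(1)], interval=2)
first = pipe.forward(1.0)
second = pipe.forward(1.0)
```

In this sketch the check runs once per chunk at construction time, so repeated `forward` calls only index into a list of booleans.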