Move is_checkpointable call to reduce torch.compile graph breaks (#5759)
We encountered a performance issue when running torch.compile on a
model that uses the pipeline engine (Mixtral).
The cause is the is_checkpointable function, which is called in the
engine's forward function.
With torch.compile, this call creates a graph break, which degrades
performance, particularly because it happens on every forward call.
We propose changing how is_checkpointable is checked: compute and store
its value before the forward call, then read the stored values in the
forward function.
With this change, the graph break in the forward call is avoided, which
should improve performance under torch.compile.
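The pattern can be sketched roughly as follows. This is a minimal,
framework-free illustration and not the actual DeepSpeed code: the class
PipelineSketch, the attribute checkpointable_flags, the interval
parameter, and the body of is_checkpointable are all hypothetical
stand-ins chosen for the example.

```python
from typing import Callable, List, Sequence

Layer = Callable[[float], float]


def is_checkpointable(funcs: Sequence[Layer]) -> bool:
    # Hypothetical stand-in for the real check. Python-level introspection
    # of this kind is what we want to keep out of the compiled forward path.
    return all(getattr(f, "checkpointable", False) for f in funcs)


class PipelineSketch:
    """Hypothetical pipeline module that runs its layers in fixed-size chunks."""

    def __init__(self, layers: Sequence[Layer], interval: int) -> None:
        self.layers: List[Layer] = list(layers)
        self.interval = interval
        # Precompute the check once per chunk, before any forward call.
        self.checkpointable_flags: List[bool] = [
            is_checkpointable(self.layers[start:start + interval])
            for start in range(0, len(self.layers), interval)
        ]

    def forward(self, x: float) -> float:
        starts = range(0, len(self.layers), self.interval)
        for idx, start in enumerate(starts):
            chunk = self.layers[start:start + self.interval]
            # Read the stored flag instead of calling is_checkpointable here.
            use_checkpoint = self.checkpointable_flags[idx]
            # A real engine would route the chunk through activation
            # checkpointing when use_checkpoint is True; this sketch runs
            # the layers directly in both cases.
            for layer in chunk:
                x = layer(x)
        return x
```

The stored flags are only valid as long as the layer list does not change
after construction; an implementation that mutates its layers later would
need to recompute them.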
Co-authored-by: Heyang Qin <heyangqin@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>