DeepSpeed
c2bb53f2 - TiledMLP + SequenceTiledCompute: improve the bs>1 use-case (#7422)

Commit
146 days ago
TiledMLP + SequenceTiledCompute: improve the bs>1 use-case (#7422)

Improved TiledMLP and SequenceTiledCompute for bs>1.

This PR:
- extends the testing utils to add `CaptureStd*`, `CaptureLogger` context managers
- extends the test to run both bs=1 and bs=2
- uses an uneven seqlen to test varlen shards
- flattens the bs+seqlen dims, to avoid problems with grad tensor strides when bs>1
- treats `bs*seqlen` as a pretend seqlen, since the MLP doesn't depend on the bs dimension, and restores the shape at the end for the grad

---------

Signed-off-by: Stas Bekman <stas@stason.org>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
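
The flattening idea described in the commit message can be illustrated with a small sketch. This is not DeepSpeed's implementation: it uses NumPy arrays as a stand-in for torch tensors, covers only the forward pass (not the grad handling), and the names `mlp`, `tiled_mlp`, `w_up`, `w_down` and `num_shards` are hypothetical. It only shows why a position-wise MLP can treat `bs*seqlen` as a single sequence dimension, be computed over uneven shards, and still give the same result once the original shape is restored.

```python
import numpy as np


def mlp(x, w_up, w_down):
    # Position-wise MLP: every token is transformed independently,
    # so the result does not depend on how tokens are grouped into batches.
    return np.maximum(x @ w_up, 0.0) @ w_down


def tiled_mlp(x, w_up, w_down, num_shards):
    # Hypothetical helper, for illustration only.
    bs, seqlen, hidden = x.shape
    # Flatten bs and seqlen into one pretend sequence dimension.
    flat = x.reshape(1, bs * seqlen, hidden)
    # array_split allows uneven shards when bs*seqlen is not divisible.
    shards = np.array_split(flat, num_shards, axis=1)
    out = np.concatenate([mlp(s, w_up, w_down) for s in shards], axis=1)
    # Restore the original [bs, seqlen, hidden] shape at the end.
    return out.reshape(bs, seqlen, hidden)


rng = np.random.default_rng(0)
bs, seqlen, hidden = 2, 7, 4  # uneven seqlen: 14 tokens over 4 shards
x = rng.standard_normal((bs, seqlen, hidden))
w_up = rng.standard_normal((hidden, 8))
w_down = rng.standard_normal((8, hidden))

tiled = tiled_mlp(x, w_up, w_down, num_shards=4)
full = mlp(x, w_up, w_down)
```

Here `tiled` and `full` should agree up to floating-point tolerance, which is the property the bs=1 and bs=2 tests rely on conceptually.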