TiledMLP + SequenceTiledCompute: improve the bs>1 use-case (#7422)
Improved TiledMLP and SequenceTiledCompute for bs>1
This PR:
- extends the testing utils with the `CaptureStd*` and `CaptureLogger`
context managers
- extends the test to run both bs=1 and bs=2
- uses an uneven seqlen to test varlen shards
- flattens the bs and seqlen dims to avoid problems with grad tensor
strides when bs>1. The MLP doesn't care about the bs dimension, so a
pretend seqlen of `bs*seqlen` is used instead, and the original shape is
restored at the end for the grad.
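
The flatten-and-restore idea from the last item can be sketched in plain
Python. This is only an illustration, not the code in this PR: nested
lists stand in for tensors, and `tiled_apply`, `token_fn` and `double`
are hypothetical names. It assumes the per-token function (the MLP
stand-in) is independent of the bs dimension.

```python
import math

def tiled_apply(x, token_fn, num_shards):
    """Hypothetical helper. x is a nested list shaped [bs][seqlen][hidden];
    token_fn maps one token vector to another."""
    bs, seqlen = len(x), len(x[0])
    # Flatten [bs][seqlen] into one pretend sequence of bs*seqlen tokens.
    flat = [tok for seq in x for tok in seq]
    # Split into shards; the last one is shorter when the length is uneven.
    shard_len = math.ceil(len(flat) / num_shards)
    shards = [flat[i:i + shard_len] for i in range(0, len(flat), shard_len)]
    # Process shard by shard, then concatenate the results.
    out_flat = [token_fn(tok) for shard in shards for tok in shard]
    # Restore the original [bs][seqlen] layout at the end.
    out = [out_flat[b * seqlen:(b + 1) * seqlen] for b in range(bs)]
    return out, [len(s) for s in shards]

def double(tok):
    # Stand-in for the MLP: acts on each token independently.
    return [2 * v for v in tok]

# bs=2, seqlen=5 (uneven with respect to 4 shards), hidden=2
x = [[[float(b * 10 + s)] * 2 for s in range(5)] for b in range(2)]
out, shard_sizes = tiled_apply(x, double, num_shards=4)
```

Here the 10 flattened tokens are split into shards of sizes 3, 3, 3 and 1,
and `out` matches what applying `double` to each batch entry separately
would give.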
---------
Signed-off-by: Stas Bekman <stas@stason.org>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>