Expand NT subclass to support SAM (#109123)
This PR contains the changes needed to support using the NT jagged subclass within SAM. Note that a NT with multiple ragged dims is still required at the extremes for inputs / outputs, but the internal computation generally involves a single ragged dim, making the jagged layout usable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109123
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer