[spmd] expose input_batch_dim to DataParallelMode (#99899)
This PR exposes the input batch dim to the DataParallelMode so that
we could have explicit control of which input dim is batch dim
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99899
Approved by: https://github.com/awgu, https://github.com/mrshenli