Fix layernorm and softmax axis after upstream (#17255)
### Fix layernorm and softmax axis after upstream
For Gather (the slicing is a scalar), the output rank is small than its
inputs.
When we upstream this kind of Gather before softmax or layernorm, we
should also update the axis attribute.
Otherwise, the axis might be out-of-date and incorrect for the updated
rank.
```
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_fallback.py", line 157, in handle_exception
raise exception
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 280, in forward
self._build_graph(graph_transformer_config)
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 158, in wrapper
result = func(graph_execution_manager, *args, **kwargs)
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 273, in wrapper
result = func(graph_execution_manager, *args, **kwargs)
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 361, in _build_graph
super()._build_graph(graph_transformer_config)
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 184, in _build_graph
self._graph_builder.build(config)
RuntimeError: /onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc:823 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const onnxruntime::training::TrainingGraphTransformerConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Node (Softmax_2904) Op (Softmax) [ShapeInferenceError] 'axis' must be in [-3 , 2]. Its actual value is: 3
```