Allow upstream for Slice on single axis (#16410)
### Allow upstream for Slice on single axis
#### Benchmark on 8x32GB V100 + DeepSpeed
On Bloom560M model, there is 1.5% throughput gains on the same max batch
size 6.
```
torchrun --nproc_per_node=8 examples/onnxruntime/training/language-modeling/run_clm.py --model_name_or_path bigscience/bloom-560m --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --num_train_epochs 10 --per_device_train_batch_size 6 --per_device_eval_batch_size 1 --do_train --overwrite_output_dir --output_dir ./outputs/ --seed 1137 --fp16 --report_to none --optim adamw_ort_fused --max_steps 200 --logging_steps 1 --use_module_with_loss --deepspeed aml_ds_config_zero_1.json
```
##### Main branch
```
Total overhead: 38957ms where export takes 35493ms.
***** train metrics *****
epoch = 4.08
train_loss = 2.6841
train_runtime = 0:03:10.67
train_samples = 2318
train_samples_per_second = 50.348
train_steps_per_second = 1.049
throughput per gpu=4.08 * 2318 / (190.67 - 38.957) / 8(gpu) = 7.792 samples/second
```
##### This PR
```
Total overhead: 38649ms where export takes 34946ms.
***** train metrics *****
epoch = 4.08
train_loss = 2.6757
train_runtime = 0:03:08.08
train_samples = 2318
train_samples_per_second = 51.04
train_steps_per_second = 1.063
throughput per gpu=4.08 * 2318 / (188.08 - 38.649) / 8(gpu) = 7.911 samples/second
```
#### Benchmark on 4x16GB V100 + AutoCast
On Bloom560M model, there is 1.8% throughput gains on the same batch
size, 24% gains with corresponding maximum batch size.
Also it allow ORT run bigger batch size (from 3 to 4) on following
recipe.
```
torchrun --nproc_per_node=4 examples/onnxruntime/training/language-modeling/run_clm.py --model_name_or_path bigscience/bloom-560m --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --num_train_epochs 10 --per_device_train_batch_size 3 --per_device_eval_batch_size 1 --do_train --overwrite_output_dir --output_dir ./outputs/ --seed 1137 --fp16 --report_to none --optim adamw_ort_fused --max_steps 200 --logging_steps 1 --use_module_with_loss
```
##### Main branch
```
Total overhead: 4789ms where export takes 3798ms.
***** train metrics *****
epoch = 1.02
train_loss = 20.3338
train_runtime = 0:01:42.78
train_samples = 2343
train_samples_per_second = 23.349
train_steps_per_second = 1.946
throughput per gpu=1.02 * 2343 / (102.78 - 4.789) / 4(gpu) = 6.097 samples/second
```
##### This PR
```
Total overhead: 4608ms where export takes 3555ms.
***** train metrics *****
epoch = 1.02
train_loss = 20.3364
train_runtime = 0:01:40.87
train_samples = 2343
train_samples_per_second = 23.792
throughput per gpu=1.02 * 2343 / (100.87 - 4.608) / 4(gpu) = 6.207 samples/second
```
With this PR, also can run batch size 4 (main branch fails),
```
Total overhead: 4743ms where export takes 3698ms.
***** train metrics *****
epoch = 1.36
train_loss = 20.2096
train_runtime = 0:01:50.42
train_samples = 2343
train_samples_per_second = 28.979
train_steps_per_second = 1.811
throughput per gpu= 1.36 * 2343 / (110 - 4.743) / 4(gpu) =7.57 sample/second
```
#### Benchmark on 8x32GB V100 + AutoCast
On Bloom560M model, there is 0.9% throughput gains on the same batch
size, 8.6% gains with corresponding maximum batch size.
Also it allow ORT run bigger batch size (from 3 to 4) on following
recipe.
```
torchrun --nproc_per_node=8 examples/onnxruntime/training/language-modeling/run_clm.py --model_name_or_path bigscience/bloom-560m --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --num_train_epochs 10 --per_device_train_batch_size 3 --per_device_eval_batch_size 1 --do_train --overwrite_output_dir --output_dir ./outputs/ --seed 1137 --fp16 --report_to none --optim adamw_ort_fused --max_steps 200 --logging_steps 1 --use_module_with_loss
```
##### Main branch
```
Total overhead: 55259ms where export takes 51140ms.
***** train metrics *****
epoch = 2.06
train_loss = 2.8788
train_runtime = 0:02:36.65
train_samples = 2318
train_samples_per_second = 30.64
train_steps_per_second = 1.277
throughput per gpu=2.06 * 2318 / (156.65 - 55.259) / 8(gpu) = 5.887 samples/second
```
##### This PR
```
Total overhead: 55712ms where export takes 51418ms.
***** train metrics *****
epoch = 2.06
train_loss = 2.8696
train_runtime = 0:02:36.19
train_samples = 2318
train_samples_per_second = 30.731
train_steps_per_second = 1.28
throughput per gpu=2.06 * 2318/ (156.19 - 55.712) / 8(gpu) = 5.940 samples/second
```
With this PR, also can run batch size 4 (main branch fails),
```
Total overhead: 54238ms where export takes 49899ms.
***** train metrics *****
epoch = 2.74
train_loss = 2.7692
train_runtime = 0:02:58.47
train_samples = 2318
train_samples_per_second = 35.859
train_steps_per_second = 1.121
throughput per gpu= 2.74 * 2318 / (178.47 - 54.238) / 8(gpu) =6.391sample/second
```