onnxruntime
4fec745e - Fix CrossEntropyLoss block to support multi-output models (#28232)

Commit
28 days ago
Fix CrossEntropyLoss block to support multi-output models (#28232)

## Summary

- `artifacts.generate_artifacts(..., loss=LossType.CrossEntropyLoss)` no longer aborts with `i < node_->OutputDefs().size()` when the base model has multi-dimensional outputs.
- The `SoftmaxCrossEntropyLoss` op produces two outputs (`loss`, `log_prob`); the second was being dropped by graph optimizers because it had no `value_info` entry, leaving the gradient builder to dereference a missing output def via `O(1)`.

## Motivation

Fixes #22465. Users hit a hard C++ assertion when training models like DistilBERT whose forward graph emits a multi-dimensional last-hidden-state tensor. The same pattern appears for any seq2seq / LM training setup that pipes a 3-D output into `CrossEntropyLoss`. This is a Python-only change scoped to the `onnxblock` training-artifacts API; the core inference engine is unaffected.

## Changes

- `orttraining/orttraining/python/training/onnxblock/loss/loss.py`: after appending the `SoftmaxCrossEntropyLoss` node, register a `value_info` entry for `log_prob_output_name` so its output def survives shape inference and graph cleanup. Idempotent: guarded against duplicate entries.
- `orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py`: new `test_crossentropy_loss_multi_output_model` builds a 3-D output toy model, calls `generate_artifacts` with `LossType.CrossEntropyLoss`, and asserts the saved `training_model.onnx` retains both outputs on the SCE node.

## Test Plan

- New test exercises the previously-failing path: `python -m pytest orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py::test_crossentropy_loss_multi_output_model -v`
- Existing CE coverage: `python -m pytest orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py -k crossentropy -v`
- `lintrunner` clean on the diff.

Fixes #22465