Merge GatherToSplitFusion and #19218 to a General Fusion (#19600)
#19218 tried to fuse Gather/Slice to Split, but the logic has problem.
Scalar value or 1-dim value of indices in Gather node will produce
different result, scalar value will produce a result tensor by removing
the axis dim, will 1-dim indices value will keep that dim, even when the
dim value is 1. For example,
Node
|-> Gather(indices=[0], axis=axis)
|-> Gather(indices=[1], axis=axis)
|-> Slice(index=2, axis=axis)
is same as
Node
|-> Split(axis=axis)
But
Node
|-> Gather(indices=0, axis=axis)
|-> Gather(indices=1, axis=axis)
|-> Slice(index=2, axis=axis)
is same as
Node
|-> Split(axis=axis)
||-> Squeeze(axis=axis)
||-> Squeeze(axis=axis)
||->
Previous PR doesn't take such case related to Squeeze/Unsqueeze into
account.
This PR merges #19218 and GatherToSplitFusion to a general fusion, which
relaxes the limit the number of Gather and Slice node number, check all
Gather and Slice consumers, if the indices of Gather and start/end of
Slice can cover the specific dim of the input tensor, then we can fuse
them to a Split, and adding Squeeze if necessary according to the dim
count of the indices tensor in Gather.
@rui-ren, please check if the fix can still be applied to your model.