[GPU] Add the capability for KV cache to update past KV (#33114)
### Details:
This PR recognizes the pattern of ScatterElementsUpdate + Slice
nodes (blue nodes in the picture below) and fuses them into a
multi-stage KVCache node. In addition, past_seq_len from the ONNX GQA
operator, which corrects the length of the KV cache, was missing from
the decomposition of the ONNX operator; it is added in this PR so that
GQA benefits from the new KVCache capability.
After fusion, two related changes take place:
1. ScatterElementsUpdate is handled by adding a reorder_stage that
executes the ScatterElementsUpdate kernel.
2. Slice is handled as an in-place crop by updating the data padding of
the variable state.
The picture below shows the graph before and after fusion.
<img width="2503" height="1183" alt="image"
src="https://github.com/user-attachments/assets/4743da6b-edf9-4fe9-b7e2-158c81fc7abd"
/>
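The tensor semantics of the fused ScatterElementsUpdate + Slice pattern can be sketched in numpy terms. This is only a conceptual illustration of what the pattern computes along the sequence axis, not the plugin implementation; the names `past_kv`, `new_kv`, and the layout `[batch, heads, seq, head_dim]` are illustrative assumptions.

```python
import numpy as np

def scatter_then_slice(past_kv, new_kv, past_seq_len):
    """Conceptual semantics of the fused pattern on the sequence axis.

    past_kv: [batch, heads, max_seq, head_dim] KV cache buffer
    new_kv:  [batch, heads, new_len, head_dim] tokens to append
    """
    cache = past_kv.copy()
    new_len = new_kv.shape[2]
    # ScatterElementsUpdate: write the new tokens at positions
    # past_seq_len .. past_seq_len + new_len - 1 of the buffer
    cache[:, :, past_seq_len:past_seq_len + new_len, :] = new_kv
    # Slice: crop the buffer to the valid length (after fusion this
    # becomes an in-place crop via the variable state's data padding)
    return cache[:, :, :past_seq_len + new_len, :]

past = np.zeros((1, 2, 8, 4), dtype=np.float32)  # max_seq = 8
new = np.ones((1, 2, 3, 4), dtype=np.float32)    # append 3 tokens
out = scatter_then_slice(past, new, past_seq_len=2)
print(out.shape)  # (1, 2, 5, 4)
```

In the fused node no extra copy is materialized for the Slice step; the sketch above only shows the values the pattern produces.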
### Motivation and Context
The target application leverages tree-based speculative decoding to
accelerate LLM inference. This technique requires frequent manipulation
of past KV cache states (e.g. trimming, reordering), because only a
single branch of the speculative draft tree is accepted after
verification.
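The trim/reorder operation described above amounts to a gather along the sequence axis of the cache. A hypothetical numpy sketch (the tree layout and the `accepted` index set are invented for illustration):

```python
import numpy as np

# KV cache holding a 3-token verified prefix plus 3 speculative
# draft tokens, layout [batch, heads, seq, head_dim]
kv = np.arange(6, dtype=np.float32)[None, None, :, None]  # shape (1, 1, 6, 1)

# After verification, keep the prefix and only the accepted branch of
# the draft tree (here: draft position 3 accepted, 4 and 5 rejected).
# The past KV manipulation is a gather along the sequence axis.
accepted = np.array([0, 1, 2, 3])
kv = kv[:, :, accepted, :]
print(kv.shape)  # (1, 1, 4, 1)
```

Doing this through generic state APIs on every decode step is what proved too slow, motivating a dedicated fused node.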
The current KV Cache API available in OV is too slow to meet customer
requirements; details in
[CVS-174809](https://jira.devtools.intel.com/browse/CVS-174809). As the
OV team suggested, the way to support the reorder feature is to add
specific nodes to the original graph. This PR recognizes the pattern of
the added nodes and fuses them into a multi-stage KVCache node for
better performance.
### Tickets:
[CVS-176367](https://jira.devtools.intel.com/browse/CVS-176367)
### Related PR
#32708
---------
Co-authored-by: Dvoretckii, Mikhail <mikhail.dvoretckii@intel.com>
Co-authored-by: czekun <chen.zekun@intel.com>