5efb24ac - Merging AutoSP into DeepSpeed (#7860)

# AutoSP: Unlocking Long-Context LLM Training via Compiler-Based Sequence Parallelism

## Overview

AutoSP is a compiler optimization pass that shards inputs along the sequence dimension and enables Ulysses-style sequence parallelism while preventing graph breaks during `torch.compile()`. All passes operate on the Torch IR of the forward graph.

## API Design

### User-Facing Entry Point: `prepare_autosp_inputs()`

Users must explicitly call this function to prepare inputs for AutoSP compilation:

```python
def prepare_autosp_inputs(
    input_id: torch.Tensor,
    label_id: torch.Tensor,
    position_id: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
    seq_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
```

**Purpose**: Symbolize the sequence dimension and annotate tensors for identification.

**Operations**:
1. Mark the sequence dimension as dynamic using `torch._dynamo.decorators.mark_dynamic()`
2. Attach metadata tags so the passes can identify tensors for auto-sharding:
   - `input_id.tag = constants.INPUT_ID_KEY`
   - `label_id.tag = constants.LABEL_ID_KEY`
   - `position_id.tag = constants.POSITION_ID_KEY` (if provided)

**Rationale**: PyTorch's FX graph tracer requires explicit annotation of data-dependent dimensions. Marking the sequence dimension as dynamic prevents symbolic shape propagation from losing dimension information through reshape/view operations.

## Compilation Passes

### Pass 1: `pass_shard_seq_dim()`

**Objective**: Propagate the sharded sequence dimension to all consumers.

**Algorithm**:
1. Extract the symbolic sequence dimension from the `input_id` shape metadata
2. Locate the symbolic dimension node in the FX graph
3. Create a floor-divide node: `seq_dim // world_size`
4. Perform a worklist-based graph traversal to find all direct and indirect consumers of the input, label, and position-id nodes
5. Replace symbolic dimension references with the sharded dimension in consumer nodes

**Rationale**: Reshapes and views that consume the sequence dimension as an argument are not updated during propagation of symbolic shapes. This pass explicitly rewires the computation graph to use sharded dimensions, enabling proper shape inference downstream.

### Pass 2: `pass_shard_input_ids()` / `pass_shard_label_ids()` / `pass_shard_position_ids()`

**Objective**: Insert slicing operations after input tensors.

**Implementation**: Call the `shard_tensor_node()` utility, which inserts slice operations. Each rank retains only the portion of the tensor corresponding to its sequence partition and drops the remaining buffer.

**Note on `attention_mask`**: Not sharded, because it applies to the full sequence length rather than the partitioned dimension.

### Pass 3: `pass_insert_attention_all_to_all()`

**Objective**: Insert all-to-all collectives around attention (Ulysses-style) to avoid graph breaks during compilation.

**Algorithm**:
1. Identify all SDPA (Scaled Dot-Product Attention) nodes in the graph
2. For each SDPA node with inputs Q, K, V, insert an A2A after each of Q, K, V: scatter heads (dim=1), gather sequence (dim=2)
3. Insert an A2A after the attention output O: scatter sequence (dim=2), gather heads (dim=1)

**Graph Rewrite Example**:

```
Q [B, N, S/P, H] --A2A(scatter_heads, gather_seq)--> [B, N/P, S, H]
K [B, N, S/P, H] --A2A(scatter_heads, gather_seq)--> [B, N/P, S, H]
V [B, N, S/P, H] --A2A(scatter_heads, gather_seq)--> [B, N/P, S, H]
                            |
                          SDPA
                            |
O [B, N/P, S, H] --A2A(scatter_seq, gather_heads)--> [B, N, S/P, H]
```

**Current support**: Only `torch.nn.functional.scaled_dot_product_attention()` is supported. Composite attention patterns require additional pattern-matching logic.

### Pass 4: `pass_propagate_shapes()`

**Objective**: Compute static shapes for all nodes using fake tensor execution.

**Implementation**:
1. Create a `ShapeEnv` for symbolic dimension tracking
2. Construct a `FakeTensorMode` with the shape environment
3. Execute `FakeTensorProp.propagate()` to compute shape metadata

### Pass 5: `pass_canonicalize()`

**Objective**: Finalize the graph representation.

**Operations**:
1. `eliminate_dead_code()`: Remove unused operations
2. `lint()`: Validate the graph structure
3. `recompile()`: Regenerate the compiled representation

## Execution Order

```
prepare_autosp_inputs()
          ↓
pass_shard_seq_dim
          ↓
pass_shard_input_ids
          ↓
pass_shard_label_ids
          ↓
pass_shard_position_ids
          ↓
pass_insert_attention_all_to_all
          ↓
pass_propagate_shapes
          ↓
pass_canonicalize
          ↓
pass_selective_activation_checkpointing
```

## Memory Savings

AutoSP adds heuristics to `torch.compile`'s partitioner, which splits the joint graph into the forward and backward graphs. Matmul and related ops are not checkpointed, since recomputing them is much cheaper compared to the attention op, while still reducing peak active memory.

## Reducing Gradients Across Ranks

AutoSP requires an all-reduce to reduce the gradients across ranks. This is called automatically by DeepSpeed's engine [here](https://github.com/deepspeedai/DeepSpeed/blob/93524c8931799a7631a2321d7ef4afaff6b6e54b/deepspeed/runtime/engine.py#L2433).

## Known Limitations

1. **Attention Pattern Matching**: Only `torch.nn.functional.scaled_dot_product_attention()` is supported. Fused attention implementations require pattern-specific handling.
2. **No-Graph-Break Requirement**: AutoSP will fail if there are graph breaks, because use-def chains are lost and it becomes difficult to propagate auto-sharding information across graph modules.
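The worklist-based consumer traversal described in Pass 1 can be sketched in plain Python over a `node -> users` adjacency map. The helper name `find_consumers` and the toy graph below are illustrative only, not the actual DeepSpeed implementation, which walks real FX nodes.

```python
def find_consumers(users, roots):
    """Worklist traversal: collect every node reachable from `roots`
    through use edges, i.e. all direct and indirect consumers."""
    seen = set()
    work = list(roots)
    while work:
        node = work.pop()
        for user in users.get(node, ()):
            if user not in seen:
                seen.add(user)
                work.append(user)
    return seen

# Toy FX-like graph: input_id feeds an embedding, which feeds a view and a matmul.
users = {
    "input_id": ["embed"],
    "embed": ["view", "matmul"],
    "view": ["matmul"],
}
print(sorted(find_consumers(users, ["input_id"])))  # ['embed', 'matmul', 'view']
```

Collecting the full transitive consumer set up front is what lets the pass rewrite every downstream reshape/view in one sweep.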
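The slicing inserted by Pass 2 keeps, on each rank, one contiguous shard of the sequence. A minimal sketch of the index arithmetic, assuming (as the passes implicitly do) that the sequence length divides evenly by the world size; `shard_bounds` is a hypothetical helper name:

```python
def shard_bounds(seq_len, world_size, rank):
    """Return the [start, end) sequence indices that `rank` keeps
    after the slice operation inserted by the sharding pass."""
    assert seq_len % world_size == 0, "sequence length must divide evenly"
    shard = seq_len // world_size
    return rank * shard, (rank + 1) * shard

# With seq_len=8 and world_size=4, rank 2 keeps tokens [4, 6).
print(shard_bounds(8, 4, 2))  # (4, 6)
```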
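The shape transitions in the Pass 3 graph-rewrite example can be verified with simple bookkeeping: scattering heads over `P` ranks divides dim 1 by `P` while gathering the sequence multiplies dim 2 by `P`, and the output A2A reverses this. The helpers below are an illustrative sketch of that arithmetic, not the collective itself:

```python
def a2a_scatter_heads_gather_seq(shape, world_size):
    """[B, N, S/P, H] -> [B, N/P, S, H]: scatter dim 1, gather dim 2."""
    b, n, s_local, h = shape
    assert n % world_size == 0, "head count must divide by world size"
    return (b, n // world_size, s_local * world_size, h)

def a2a_scatter_seq_gather_heads(shape, world_size):
    """[B, N/P, S, H] -> [B, N, S/P, H]: scatter dim 2, gather dim 1."""
    b, n_local, s, h = shape
    assert s % world_size == 0, "sequence length must divide by world size"
    return (b, n_local * world_size, s // world_size, h)

# B=2, N=16 heads, full S=1024, H=64, P=4 ranks; each rank starts with S/P=256.
q_local = (2, 16, 256, 64)
after_in = a2a_scatter_heads_gather_seq(q_local, 4)    # (2, 4, 1024, 64)
after_out = a2a_scatter_seq_gather_heads(after_in, 4)  # back to (2, 16, 256, 64)
print(after_in, after_out)
```

Note how SDPA runs on the full sequence with a head-parallel split, which is exactly the Ulysses layout.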
## Example

DeepSpeedExamples PR: https://github.com/deepspeedai/DeepSpeedExamples/pull/999

---------

Signed-off-by: Neel Dani <neeldani98@gmail.com>
Signed-off-by: Ahan Gupta <ahangupta.96@gmail.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Ahan Gupta <ahangupta.96@gmail.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Zhipeng Wang <zhipeng.rainbowserie@gmail.com>