transformers
bcecab89 - [distributed] Add param-level apply pass to apply_tensor_parallel

Commit
23 hours ago
[distributed] Add param-level apply pass to apply_tensor_parallel Run tensor parallelism in two passes: - Pass 1 (param-level): walk named_parameters() and, for styles in PARAM_ONLY_STYLES (grouped_gemm, moe_gate_up_colwise[_alt], moe_down_rowwise), shard the parameter directly via shard_parameters(). No forward hook. - Pass 2 (module-level): the existing named_modules() loop for forward hooks, now skipping PARAM_ONLY_STYLES. Param sharding runs first so module forward hooks (moe_experts_allreduce) see the already-sharded DTensor params. Also wire the EP-plan fallback so enable_expert_parallel uses model._ep_plan when no explicit plan is passed.
Author
Parents
Loading