transformers
d55f0350 - Fix T5Attention shape mismatch under Tensor Parallelism (#45109)

Commit
25 days ago
Fix T5Attention shape mismatch under Tensor Parallelism (#45109) * Fix T5Attention shape mismatch under Tensor Parallelism T5Attention.forward hard-codes n_heads and inner_dim in view() calls. When ColwiseParallel shards q/k/v projections, the output dim becomes inner_dim / tp_size, but n_heads stays unchanged, causing RuntimeError on view(). Move the -1 (auto-infer) from the seq_length dim to the n_heads dim so that num_heads is derived from the actual tensor shape. * Run make fix-repo: propagate TP view fix to copied models Sync formatting and view() changes to longt5, mt5, pop2piano, switch_transformers, and udop via Copied-from mechanism. * Refactor view() calls to use shape tuples per review Define q_input_shape and kv_input_shape tuples before view() calls, following the modern style used in BertCrossAttention. * Chain view() calls per reviewer suggestion
Author
Parents
Loading