DeepSpeed
4559dadd - Cache metadata for TP activations and grads (#4360)

PartitionedTensor.from_meta causes a device-to-host synchronization when it reads the meta tensor via meta = meta.tolist(). This change adds a CPU-side cache for the meta tensor to avoid that synchronization on every activation and gradient communication between ranks. The meta tensor is assumed to be static, since the activation shape must be static; the user must call reset_activation_shape if any of the dimensions change.

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
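The caching idea can be sketched in miniature: the first read of the meta tensor pays the device-to-host sync, and subsequent reads reuse the cached host-side list until the shape is declared changed. The class and tensor below are illustrative stand-ins, not DeepSpeed's actual implementation; a tiny fake tensor counts `tolist()` calls to stand in for the sync cost.

```python
class FakeMetaTensor:
    """Stand-in for a device meta tensor; tolist() simulates a D2H sync."""

    def __init__(self, values):
        self._values = values
        self.sync_count = 0  # number of simulated device-to-host syncs

    def tolist(self):
        self.sync_count += 1  # each call would block on the device
        return list(self._values)


class MetaCache:
    """Hypothetical CPU-side cache for static activation metadata."""

    def __init__(self):
        self._cached = None

    def get_meta_list(self, meta_tensor):
        # Only the first call reads the device tensor; later calls reuse
        # the cached list because activation shapes are assumed static.
        if self._cached is None:
            self._cached = meta_tensor.tolist()
        return self._cached

    def reset_activation_shape(self):
        # Mirrors the commit's contract: invalidate the cache if any
        # activation dimension changes.
        self._cached = None


meta = FakeMetaTensor([2, 4, 1024])
cache = MetaCache()
cache.get_meta_list(meta)   # pays the one-time sync
cache.get_meta_list(meta)   # served from the CPU cache, no sync
print(meta.sync_count)      # → 1
```

After `reset_activation_shape()`, the next `get_meta_list` call would sync again, matching the documented requirement that the user invalidate the cache whenever a dimension changes.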