Cache metadata for TP activations and grads (#4360)
`PartitionedTensor.from_meta` causes a device-to-host synchronization when it reads the meta tensor via `meta = meta.tolist()`.

This change adds a CPU cache for the meta tensor to avoid that synchronization on every activation and gradient communication between ranks. The meta tensor is assumed to be static, since the activation shape must be static. The user must call `reset_activation_shape` if any of the dimensions change.
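The caching pattern can be sketched roughly as follows. This is a minimal illustration of the idea, not DeepSpeed's actual implementation; the `MetaCache` class and its method names are hypothetical, and the stand-in assumes only that the meta object exposes a `tolist()` method (as a PyTorch tensor does, where calling it on a CUDA tensor forces a device-to-host sync):

```python
class MetaCache:
    """Hypothetical host-side cache for a static meta tensor.

    Calling .tolist() on a CUDA tensor triggers a device-to-host
    synchronization, so we do it once and reuse the Python list for
    every subsequent activation/grad communication.
    """

    def __init__(self):
        self._cached = None

    def get(self, meta):
        if self._cached is None:
            # One-time device-to-host copy; synchronizes only here.
            self._cached = meta.tolist()
        return self._cached

    def reset_activation_shape(self):
        # Must be called if activation dimensions change, so the next
        # get() re-reads the (new) meta tensor from the device.
        self._cached = None
```

Because the activation shape is static for a given run, the cached list stays valid until `reset_activation_shape` invalidates it.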
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>