pytorch
568db1b4 - [dtensor] Relax condition for _split_tensor() (#101218)

Commit
1 year ago
[dtensor] Relax condition for _split_tensor() (#101218) When tensor.size(self.dim) < num_chunks, we will fill empty chunk with empty tensor (https://github.com/pytorch/pytorch/pull/98722). Therefore, we no longer needs this assert. For example, when sharding a tensor with 1 element on 2 ranks along dim 0, results would be as follows: ``` rank:0, dtensor:DTensor(local_tensor=tensor([0.4963], device='cuda:0'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)]) rank:1, dtensor:DTensor(local_tensor=tensor([], device='cuda:1'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)]) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/101218 Approved by: https://github.com/wanchaol
Author
Committer
Parents
Loading