pytorch
cc6a51c9 - added shape checking to WeightedRandomSampler (#78585)

Commit
3 years ago
added shape checking to WeightedRandomSampler (#78585) Fixes #78236 An erronously shaped weights vector will result in the following output ``` --------------------------------------------------------------------------- ValueError Traceback (most recent call last) ~/datarwe/pytorch/torch/utils/data/sampler.py in <module> [274](file:///home/oliver/datarwe/pytorch/torch/utils/data/sampler.py?line=273) WeightedRandomSampler([1,2,3], 10) ----> [275](file:///home/oliver/datarwe/pytorch/torch/utils/data/sampler.py?line=274) WeightedRandomSampler([[1,2,3], [4,5,6]], 10) ~/datarwe/pytorch/torch/utils/data/sampler.py in __init__(self, weights, num_samples, replacement, generator) [192](file:///home/oliver/datarwe/pytorch/torch/utils/data/sampler.py?line=191) weights = torch.as_tensor(weights, dtype=torch.double) [193](file:///home/oliver/datarwe/pytorch/torch/utils/data/sampler.py?line=192) if len(weights.shape) != 1: --> [194](file:///home/oliver/datarwe/pytorch/torch/utils/data/sampler.py?line=193) raise ValueError("weights should be a 1d sequence but given " [195](file:///home/oliver/datarwe/pytorch/torch/utils/data/sampler.py?line=194) "weights have shape {}".format(tuple(weights.shape))) [196](file:///home/oliver/datarwe/pytorch/torch/utils/data/sampler.py?line=195) ValueError: weights should be a 1d sequence but given weights have shape (2, 3) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/78585 Approved by: https://github.com/NivekT, https://github.com/ejguan
Author
Oliver Sellwood
Committer
Parents
Loading