pytorch
06762b47 - Fix distributions.Categorical.sample bug from .view() (#23328)

Commit
5 years ago
Fix distributions.Categorical.sample bug from .view() (#23328) Summary: This modernizes distributions code by replacing a few uses of `.contiguous().view()` with `.reshape()`, fixing a sample bug in the `Categorical` distribution. The bug is exercised by the following test: ```py batch_shape = (1, 2, 1, 3, 1) sample_shape = (4,) cardinality = 2 logits = torch.randn(batch_shape + (cardinality,)) dist.Categorical(logits=logits).sample(sample_shape) # RuntimeError: invalid argument 2: view size is not compatible with # input tensor's size and stride (at least one dimension spans across # two contiguous subspaces). Call .contiguous() before .view(). # at ../aten/src/TH/generic/THTensor.cpp:203 ``` I have verified this works locally, but I have not added this as a regression test because it is unlikely to regress (the code is now simpler). Pull Request resolved: https://github.com/pytorch/pytorch/pull/23328 Differential Revision: D16510678 Pulled By: colesbury fbshipit-source-id: c125c1a37d21d185132e8e8b65241c86ad8ad04b
Author
Parents
Loading