Reduce memory overhead of categorical.sample (#34900)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/34714 (using the discussed solution). Thanks to jjabo for flagging and suggesting this.
Instead of expanding `probs` to prepend `sample_shape`, it is better to use the `num_samples` argument of `torch.multinomial`, which is faster and consumes less memory.
Existing tests should cover this. I have profiled this on different inputs: the change makes `.sample` faster (e.g. ~100x faster on the example in the issue), and at worst it matches the current performance with the default `sample_shape` argument.
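A rough sketch of the idea (not the exact `torch.distributions` code; shapes and names here are illustrative). The old approach expands `probs` so that every requested sample gets its own row before calling `torch.multinomial`; the new approach draws all samples from the 1-D `probs` in a single call via `num_samples` and reshapes:

```python
import torch

probs = torch.rand(1000)
probs = probs / probs.sum()          # a normalized categorical distribution
sample_shape = torch.Size([5000])    # number of samples requested

# Old approach: prepend sample_shape by expanding probs to a
# [5000, 1000] matrix, then draw one sample per row.
expanded = probs.expand(sample_shape + probs.shape)
old = torch.multinomial(expanded, 1, replacement=True).reshape(sample_shape)

# New approach: one multinomial call on the original 1-D probs,
# drawing all samples at once via num_samples.
new = torch.multinomial(probs, sample_shape.numel(), replacement=True)
new = new.reshape(sample_shape)

assert old.shape == new.shape == sample_shape
```

Both paths produce samples with the same shape and distribution; the second avoids materializing the large expanded matrix inside `multinomial`.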
cc. fritzo, alicanb, ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34900
Differential Revision: D20499065
Pulled By: ngimel
fbshipit-source-id: e5be225e3e219bd268f5f635aaa9bf7eca39f09c