pytorch
d561aa94 - Adds normal prim, randn reference, and randn OpInfo (#85128)

Commit
2 years ago
Adds normal prim, randn reference, and randn OpInfo (#85128) This PR extends prims support for random operations by adding `prims.normal` and `refs.randn`. Note that in the future we may not want to model draws from distributions as their own prims. `prims.normal` accepts a shape and the mean and standard deviation of a normal distribution as numbers. This is distinct from `torch.normal` which takes two tensors so every generated datapoint can be drawn from a normal distribution with its own mean and standard deviation. To address this @ngimel and I expect to add `prims.normal_with_tensors`. The current `prims.normal` could be implemented using `prims.normal_with_tensors`, but we expect the case of two numbers is much more common, and that executors will likely want to specialize for it, anyway. In a follow-up PR I plan to add `refs.randn_like`, `prims.normal_with_tensors` (as mentioned above), and `refs.normal`. While writing this PR I noticed the following issues: - https://github.com/pytorch/pytorch/issues/85123 - https://github.com/pytorch/pytorch/issues/85121 The latter of which is prohibiting some testing. In future PRs I plan to add a prim for changing layout, add support for pinned memory, and improve support for testing tensor creation operators, likely with a TensorCreationOpInfo class. Pull Request resolved: https://github.com/pytorch/pytorch/pull/85128 Approved by: https://github.com/ngimel
Author
Mike Ruberry
Committer
Parents
Loading