pytorch
0a101bf8 - Improve name inference API by introducing a TensorName helper struct (#28904)

Commit
5 years ago
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28904

Motivation
============
Before this PR, a core problem with writing name inference rules was that each rule needed to handle misalignment by itself. A misaligned name occurs when we are matching None with a non-None name, but the non-None name already exists in the first tensor.

For example, `A` is misaligned in `Tensor[A, None] + Tensor[None, A]`.

Each op handled this in a custom way:
- align_from_right (used by broadcasting) handles misalignment
- compute_matmul_outnames checks for misalignment across batch and feature dimensions.

We can actually codify "misalignment" into something more rigorous by folding it into the definition of `match`, which eliminates the special handling of "misalignment". That is what this PR attempts to do.

Approach
============
Definition: Two names in two tensors *match* if they are equal, or if at least one of them is a wildcard that can be *refined* to the other name.

With this new definition, to check whether two names match, we need to know the names list that each name came from, so that we can determine whether a wildcard can successfully be *refined* to the other name.

For example, consider the following:
```
tensor: Tensor[A, None]
other: Tensor[None, A]
```
When unifying `tensor.names[-1]` with `other.names[-1]`, we see that `tensor.names[-1]` is None and `other.names[-1]` is A. Then we check whether `tensor.names[-1]` can be refined to `A`; it can't be refined if there is already an `A` in `tensor.names`.

Enter `TensorNames`. A `TensorName` represents a Dimname associated with some DimnameList (that came from a Tensor). `TensorNames` is a list of such `TensorName` objects with some helper functions attached.

One can perform the following operations:
- unify two `TensorName` objects
- unify two `TensorNames` objects with right alignment.
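
The matching rule and the two operations above can be sketched in plain Python. This is only an illustrative model of the semantics described in this message, not the C++ code added by the PR: the field names `origin` and `idx`, the `unify` method, the `unify_from_right` function, and the use of `ValueError` are all assumptions made for the example, and `None` stands in for the wildcard name.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

Name = Optional[str]  # None models the wildcard name


@dataclass(frozen=True)
class TensorName:
    """A name plus the names list it came from (hypothetical Python model)."""
    origin: Tuple[Name, ...]  # the full names list of the source tensor
    idx: int                  # position of this name inside `origin`

    @property
    def name(self) -> Name:
        return self.origin[self.idx]

    def unify(self, other: "TensorName") -> Name:
        a, b = self.name, other.name
        if a == b:
            return a  # equal names (including two wildcards) match
        if a is None:
            # The wildcard can be refined to `b` only if `b` is not
            # already present in the list the wildcard came from.
            if b in self.origin:
                raise ValueError(f"misaligned name {b!r} in {self.origin}")
            return b
        if b is None:
            if a in other.origin:
                raise ValueError(f"misaligned name {a!r} in {other.origin}")
            return a
        raise ValueError(f"names {a!r} and {b!r} do not match")


def unify_from_right(names: Sequence[Name], other: Sequence[Name]) -> List[Name]:
    """Unify two names lists, aligning them at their last dimension."""
    a = [TensorName(tuple(names), i) for i in range(len(names))]
    b = [TensorName(tuple(other), i) for i in range(len(other))]
    longer, shorter = (a, b) if len(a) >= len(b) else (b, a)
    offset = len(longer) - len(shorter)
    # Leading dimensions of the longer list have no partner and pass through.
    result = [tn.name for tn in longer[:offset]]
    for x, y in zip(longer[offset:], shorter):
        result.append(x.unify(y))
    return result
```

In this model, `unify_from_right(["A", None], [None, "A"])` raises an error because `A` is misaligned, while `unify_from_right([None, "B"], ["A", None])` refines both wildcards and yields `["A", "B"]`.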

Plan
============
This PR changes `compute_matmul_outnames` to use `TensorNames` to demonstrate how they make writing name inference rules easier. In the future I'll convert other name inference rules to use `TensorNames` as well.

Test Plan
- run all tests

Test Plan: Imported from OSS

Differential Revision: D18270666

Pulled By: zou3519

fbshipit-source-id: 3ec96cc957747eb4cfe4ea17fd02ef3d8828a20c