Tests for fallback boxed dispatch (including TLS mode) (#26719)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26719
This PR adds a pair of tests for fallback boxed dispatch, exercising two different ways you might use it: (1) to implement a "wrapper" tensor type (e.g., LazyTensor, NestedTensor), and (2) to implement a toggleable "mode" (e.g., Profiling, Tracing). Both are the most trivial possible implementations of their kind: they "wrap" a real tensor and simply forward calls along to the real implementation. This PR also adds the necessary feature support for toggleable modes, which is part of the original generic dispatch abstraction design but was not previously implemented. I had not originally intended to add this, but it turns out writing a new "mode" is a lot simpler than writing a "wrapper" type, so I ended up writing the mode version first.
General structure of the PR:
* Add two new testing tensor type ids, `TESTING_ONLY_GenericWrapperTensorId` and `TESTING_ONLY_GenericModeTensorId`, which our tests use. Other tests can reuse them if needed.
* Add support for toggling the availability of `TESTING_ONLY_GenericModeTensorId`. This introduces a new thread-local variable, accessible via `tls_local_tensor_type_set()`, which is consulted during dispatch.
* The mode fallback is very simple: it increments a counter and then passes on the call to the underlying kernel by invoking the JIT.
* The wrapper fallback is more complex: it parses the arguments, unwrapping any wrapped tensor arguments, then invokes the JIT, and then rewraps the outputs.
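
The two fallbacks described above can be illustrated with a small toy model. This is only a sketch in plain Python, not the C++ code in this PR: `call_boxed`, `KERNELS`, `FALLBACKS`, `PRIORITY`, `Wrapper`, and `tls_local_type_set` are hypothetical stand-ins, the "JIT" is replaced by a dictionary of lambdas, and the way the mode disables itself while forwarding is an assumption made to avoid infinite recursion.

```python
import threading

# Toy model only: every name here is a hypothetical stand-in, not a PyTorch API.
_tls = threading.local()

def tls_local_type_set():
    """Return this thread's set of extra dispatch keys (created lazily)."""
    if not hasattr(_tls, "keys"):
        _tls.keys = set()
    return _tls.keys

# Stand-in for "the real implementation" that the fallbacks forward to.
KERNELS = {"add": lambda a, b: a + b, "neg": lambda a: -a}
FALLBACKS = {}                                # dispatch key -> boxed fallback
PRIORITY = ["GenericMode", "GenericWrapper"]  # assumed ordering: mode first

class Wrapper:
    """Trivial wrapper type holding a 'real' value in `rep`."""
    def __init__(self, rep):
        self.rep = rep

def call_boxed(op_name, stack):
    """Dispatch on the union of per-thread keys and keys carried by arguments."""
    keys = set(tls_local_type_set())
    if any(isinstance(arg, Wrapper) for arg in stack):
        keys.add("GenericWrapper")
    for key in PRIORITY:
        if key in keys and key in FALLBACKS:
            return FALLBACKS[key](op_name, stack)
    return [KERNELS[op_name](*stack)]

mode_calls = 0

def generic_mode_fallback(op_name, stack):
    """Increment a counter, then pass the call on with the mode switched off."""
    global mode_calls
    mode_calls += 1
    keys = tls_local_type_set()
    keys.discard("GenericMode")
    try:
        return call_boxed(op_name, stack)
    finally:
        keys.add("GenericMode")

def generic_wrapper_fallback(op_name, stack):
    """Unwrap wrapped arguments, forward the call, then rewrap the outputs."""
    unwrapped = [arg.rep if isinstance(arg, Wrapper) else arg for arg in stack]
    return [Wrapper(out) for out in call_boxed(op_name, unwrapped)]

FALLBACKS["GenericMode"] = generic_mode_fallback
FALLBACKS["GenericWrapper"] = generic_wrapper_fallback

# Mode: toggled through the thread-local set, no special argument type needed.
tls_local_type_set().add("GenericMode")
result = call_boxed("add", [1, 2])
tls_local_type_set().discard("GenericMode")

# Wrapper: triggered by the argument type, independent of the thread-local set.
wrapped = call_boxed("add", [Wrapper(1), 2])
```

In this model the mode needs no tensor type of its own, only a per-thread flag, which is consistent with the observation that a mode is simpler to write than a wrapper.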
The examples here are somewhat simplistic; there are a number of engineering improvements that could be applied. We could save these for later (landing this patch to get immediate testing), or incorporate them into this patch:
* `getOperator` is horrible. Bram Wasti and I discussed a plan for how to make this easier, by simply refactoring the JIT interface.
* `GenericWrapperTensorImpl` doesn't populate all of its fields accurately. Most notably, size is not set up correctly.
* `generic_wrapper_fallback` should handle tensor lists in arguments and returns properly.
One pitfall: fallback dispatch only works with non-c10 code. That's why I test using `batch_norm`.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D17549624
Test Plan: Imported from OSS
Pulled By: ezyang
fbshipit-source-id: 57dbdd8d6812a66082aa6db2934c8edcda340ea6