Precompute entries in dispatch tables (#40512)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40512
Fixes https://github.com/pytorch/pytorch/issues/32454
The heart of this diff is changing this:
```
inline const KernelFunction& Dispatcher::dispatch_(const DispatchTable& dispatchTable, DispatchKey dispatchKey) const {
  const KernelFunction* backendKernel = dispatchTable.lookup(dispatchKey);
  if (nullptr != backendKernel) {
    return *backendKernel;
  }
  const auto& backendFallbackKernel = backendFallbackKernels_[dispatchKey];
  if (backendFallbackKernel.isValid()) {
    return backendFallbackKernel;
  }
  const KernelFunction* catchallKernel = dispatchTable.lookupCatchallKernel();
  if (C10_LIKELY(nullptr != catchallKernel)) {
    return *catchallKernel;
  }
  reportError(dispatchTable, dispatchKey);
}
```
to this:
```
const KernelFunction& OperatorEntry::lookup(DispatchKey k) const {
  const auto& kernel = dispatchTable_[static_cast<uint8_t>(k)];
  if (C10_UNLIKELY(!kernel.isValid())) {
    reportError(k);
  }
  return kernel;
}
```
The difference is that instead of checking a bunch of places at
dispatch time to find the right kernel for an operator, the correct
kernel for every dispatch key is precomputed into dispatchTable_
itself (so you don't have to consult anything else at runtime.)
OperatorEntry::computeDispatchTableEntry contains that computation
(which is exactly the same as it was before.) By doing this, we are
able to substantially simplify many runtime components of dispatch.
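The precedence that computeDispatchTableEntry encodes can be sketched as follows. Everything here is an illustrative stand-in (the enum, KernelFunction as std::function, the method signatures), not the real c10 types:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>

// Illustrative stand-ins for the real c10 types.
enum class DispatchKey : uint8_t { CPU, CUDA, Autograd, NumKeys };
using KernelFunction = std::function<std::string()>;
constexpr size_t kNumKeys = static_cast<size_t>(DispatchKey::NumKeys);
using KernelTable = std::array<KernelFunction, kNumKeys>;

struct OperatorEntry {
  KernelTable kernels_;           // per-key registrations (empty == none)
  KernelFunction catchAllKernel_;
  KernelTable dispatchTable_;     // precomputed table consulted at dispatch

  // Precedence: direct kernel > backend fallback > catch-all.
  KernelFunction computeDispatchTableEntry(
      const KernelTable& backendFallbacks, DispatchKey k) const {
    auto i = static_cast<size_t>(k);
    if (kernels_[i]) return kernels_[i];
    if (backendFallbacks[i]) return backendFallbacks[i];
    return catchAllKernel_;  // may be empty, i.e. an invalid entry
  }

  // Recompute every entry; needed e.g. when a catch-all changes.
  void updateDispatchTableFull_(const KernelTable& backendFallbacks) {
    for (size_t i = 0; i < kNumKeys; ++i) {
      dispatchTable_[i] = computeDispatchTableEntry(
          backendFallbacks, static_cast<DispatchKey>(i));
    }
  }
};
```

The point of the sketch: all the branching from the old dispatch_ happens once, at registration time, so the hot path is a single array index.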
The diff is fairly large, as there are also some refactors interspersed
with the substantive change:
- I deleted the DispatchTable abstraction, folding it directly into
OperatorEntry. It might make sense to have some sort of DispatchTable
abstraction (if only to let you do operator[] on DispatchKey without
having to cast it to integers first), but I killed DispatchTable to
avoid having to design a new abstraction; the old abstraction wasn't
appropriate for the new algorithm.
- I renamed OperatorEntry::KernelEntry to AnnotatedKernel, and use it
to store backend fallbacks as well as regular kernel registrations
(this improves error messages when you incorrectly register a backend
fallback twice).
- I moved schema_ and debug_ into an AnnotatedSchema type, to make the
invariant clearer that these are set together, or not at all.
- I moved catch-all kernels out of kernels_ into their own field
(undoing a refactor I did before). The main reason is that our
intended future state is not a single catch-all, but rather
possibly multiple catch-alls which fill in different portions of
the dispatch table. This may change some more in the future: if we
allow registrations for multiple types of catch-alls, we will need
a NEW data type (representing bundles of dispatch keys) which can
represent this case, or perhaps overload DispatchKey to also
record these bundles.
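A "bundle of dispatch keys" would most naturally be a bitmask over keys. The sketch below is hypothetical (names and operations are illustrative, not an actual c10 API):

```cpp
#include <cstdint>

// Hypothetical sketch: a set of dispatch keys as a 64-bit bitmask,
// the kind of type a multi-catch-all registration would need.
enum class DispatchKey : uint8_t { CPU = 0, CUDA = 1, Autograd = 2 };

struct DispatchKeySet {
  uint64_t repr_ = 0;
  DispatchKeySet() = default;
  explicit DispatchKeySet(DispatchKey k)
      : repr_(1ULL << static_cast<uint8_t>(k)) {}
  // Union of two key sets.
  DispatchKeySet operator|(DispatchKeySet other) const {
    DispatchKeySet r;
    r.repr_ = repr_ | other.repr_;
    return r;
  }
  // Membership test for a single key.
  bool has(DispatchKey k) const {
    return repr_ & (1ULL << static_cast<uint8_t>(k));
  }
};
```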
The key changes for precomputation:
- OperatorEntry::updateDispatchTable_ is now updated to fill in the
entry at a DispatchKey, considering not only kernels (what it did
before) but also catch-alls and backend fallbacks. There is also
OperatorEntry::updateDispatchTableFull_, which updates the entire
dispatch table (necessary when someone sets a catch-all kernel).
OperatorEntry::computeDispatchTableEntry holds the canonical
algorithm specifying how we decide what function will handle a
dispatch key for the operator.
- Because dispatch table entry computation requires knowledge of
what backend fallbacks are (which is recorded in Dispatcher,
not OperatorEntry), several functions on OperatorEntry now
take Dispatcher as an argument so they can query this information.
- I modified the manual boxing wrapper invariant: previously, kernels
stored in kernels_ did NOT have manual boxing wrappers and this
was maintained by DispatchTable. Now, we just ALWAYS maintain
manual boxing wrappers for all KernelFunctions we store.
- DispatchKeyExtractor is greatly simplified: we only need to maintain
a single per-operator bitmask of what entries are fallthrough
(we don't need the global bitmask anymore).
- Introduced a new debugging 'dumpComputedTable' method, which prints
out the computed dispatch table and how each entry was computed.
This was helpful for debugging cases where the dispatch table and
the canonical metadata were out of sync.
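The per-operator fallthrough bitmask idea can be sketched like this; the names and the highest-bit-wins priority convention are illustrative assumptions, not the real DispatchKeyExtractor:

```cpp
#include <cstdint>

// Sketch: each operator keeps one bitmask recording which dispatch
// keys have fallthrough kernels; key extraction masks those out and
// picks the highest remaining key (stand-in for "highest priority").
struct KeyExtractor {
  uint64_t nonFallthroughMask_ = ~0ULL;  // bit cleared => fallthrough

  void setFallthrough(uint8_t key, bool isFallthrough) {
    if (isFallthrough) nonFallthroughMask_ &= ~(1ULL << key);
    else               nonFallthroughMask_ |=  (1ULL << key);
  }

  // Returns the highest non-fallthrough key present in `keys`,
  // or 64 if none remains after masking.
  uint8_t extract(uint64_t keys) const {
    uint64_t masked = keys & nonFallthroughMask_;
    if (masked == 0) return 64;
    uint8_t hi = 0;
    for (uint8_t i = 0; i < 64; ++i) {
      if (masked & (1ULL << i)) hi = i;
    }
    return hi;
  }
};
```

Because the mask is per-operator, registering a fallthrough kernel only touches that operator's bitmask; no global bookkeeping is required.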
Things that I didn't do but would be worth doing at some point:
- I really wanted to get rid of the C10_UNLIKELY branch testing
whether or not the KernelFunction is valid, but it looks like
I cannot easily do this while maintaining good error messages.
In principle, I could always populate a KernelFunction which
errors, but the KernelFunction needs to know which dispatch key
is missing (this is not passed in from the calling convention).
Actually, it might be possible to do something with functors,
but I didn't do it here.
- If we are going to get serious about catch-alls for subsets of
operators, we will need to design a new API for them. This diff
is agnostic to that question; we don't change the public API at all.
- Precomputation opens up the possibility of subsuming DispatchStub
by querying CPU capability when filling in the dispatch table.
This is not implemented yet. (There is also a mild blocker here,
which is that DispatchStub is also used to share TensorIterator
configuration, and this cannot be directly supported by the
regular Dispatcher.)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Differential Revision: D22236352
Pulled By: ezyang
fbshipit-source-id: d6d90f267078451816b1899afc3f79737b4e128c