pytorch
a0ba7fb4 - Precompute entries in dispatch tables (#40512)

Precompute entries in dispatch tables (#40512)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40512

Fixes https://github.com/pytorch/pytorch/issues/32454

The heart of this diff is changing this:

```
inline const KernelFunction& Dispatcher::dispatch_(const DispatchTable& dispatchTable, DispatchKey dispatchKey) const {
  const KernelFunction* backendKernel = dispatchTable.lookup(dispatchKey);
  if (nullptr != backendKernel) {
    return *backendKernel;
  }
  const auto& backendFallbackKernel = backendFallbackKernels_[dispatchKey];
  if (backendFallbackKernel.isValid()) {
    return backendFallbackKernel;
  }
  const KernelFunction* catchallKernel = dispatchTable.lookupCatchallKernel();
  if (C10_LIKELY(nullptr != catchallKernel)) {
    return *catchallKernel;
  }
  reportError(dispatchTable, dispatchKey);
}
```

to this:

```
const KernelFunction& OperatorEntry::lookup(DispatchKey k) const {
  const auto& kernel = dispatchTable_[static_cast<uint8_t>(k)];
  if (C10_UNLIKELY(!kernel.isValid())) {
    reportError(k);
  }
  return kernel;
}
```

The difference is that instead of checking a bunch of places at call time to find the right kernel for an operator, the kernel for every dispatch key is precomputed into dispatchTable_ itself, so nothing else has to be consulted at runtime. OperatorEntry::computeDispatchTableEntry contains that computation (which is exactly the same as it was before). By doing this, we are able to substantially simplify many runtime components of dispatch.

The diff is fairly large, as there are also some refactors interspersed with the substantive change:

- I deleted the DispatchTable abstraction, folding it directly into OperatorEntry. It might make sense to have some sort of DispatchTable abstraction (if only to let you do operator[] on DispatchKey without having to cast it to integers first), but I killed DispatchTable to avoid having to design a new abstraction; the old abstraction wasn't appropriate for the new algorithm.
- I renamed OperatorEntry::KernelEntry to AnnotatedKernel, and use it to store backend fallbacks as well as regular kernel registrations (this improves error messages when you incorrectly register a backend fallback twice).
- I moved schema_ and debug_ into an AnnotatedSchema type, to make the invariant clearer that these are set together, or not at all.
- I moved catch-all kernels out of kernels_ into their own property (undoing a refactor I did before). The main reason is that our intended future state is to not have a single catch-all, but rather possibly multiple catch-alls which fill in different portions of the dispatch table. This may change some more in the future: if we allow registrations for multiple types of catch-alls, we will need a NEW data type (representing bundles of dispatch keys) which can represent this case, or perhaps overload DispatchKey to also record these types.

The key changes for precomputation:

- OperatorEntry::updateDispatchTable_ now fills in the entry at a single DispatchKey, considering regular kernels (what it did before) as well as the catch-all and backend fallbacks. There is also OperatorEntry::updateDispatchTableFull_, which updates the entire dispatch table (necessary when someone sets a catch-all kernel). OperatorEntry::computeDispatchTableEntry holds the canonical algorithm specifying how we decide what function will handle a dispatch key for the operator (see the sketch after this list).
- Because dispatch table entry computation requires knowledge of what backend fallbacks are registered (which is recorded in Dispatcher, not OperatorEntry), several functions on OperatorEntry now take Dispatcher as an argument so they can query this information.
- I modified the manual boxing wrapper invariant: previously, kernels stored in kernels_ did NOT have manual boxing wrappers, and this was maintained by DispatchTable. Now, we ALWAYS maintain manual boxing wrappers for all KernelFunctions we store.
- DispatchKeyExtractor is greatly simplified: we only need to maintain a single per-operator bitmask of which entries are fallthrough (we don't need the global bitmask anymore).
- Introduced a new debugging 'dumpComputedTable' method, which prints out the computed dispatch table and how we computed it to be that way. This was helpful for debugging cases where the dispatch table and the canonical metadata were not in sync.
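To make the precomputation concrete, here is a minimal, self-contained sketch of the idea (not the actual c10 source; Key, Kernel, ToyOperatorEntry, computeEntry, updateFull, and lookup are illustrative stand-ins for DispatchKey, KernelFunction, OperatorEntry, and the methods described above). It applies the same priority order the old dispatch_ used, but once per key at registration time, so the hot-path lookup is a single array index:

```
// Illustrative sketch only -- not the real c10 dispatcher. It shows the
// precomputation idea: resolve "registered kernel > backend fallback >
// catch-all" once per dispatch key, so lookup is one indexed load.
#include <array>
#include <cstdint>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>

enum class Key : uint8_t { CPU, CUDA, Autograd, NumKeys };
constexpr size_t kNumKeys = static_cast<size_t>(Key::NumKeys);

using Kernel = std::string;  // stand-in for KernelFunction

struct ToyOperatorEntry {
  std::array<std::optional<Kernel>, kNumKeys> registered;        // per-key registrations
  std::optional<Kernel> catchAll;                                // catch-all kernel
  std::array<std::optional<Kernel>, kNumKeys> backendFallbacks;  // lives on Dispatcher in reality

  std::array<std::optional<Kernel>, kNumKeys> dispatchTable;     // the precomputed table

  // Same priority order the old Dispatcher::dispatch_ applied at call time.
  std::optional<Kernel> computeEntry(Key k) const {
    size_t i = static_cast<size_t>(k);
    if (registered[i]) return registered[i];
    if (backendFallbacks[i]) return backendFallbacks[i];
    if (catchAll) return catchAll;
    return std::nullopt;
  }

  // Analogue of updateDispatchTableFull_: recompute every slot.
  void updateFull() {
    for (size_t i = 0; i < kNumKeys; ++i) {
      dispatchTable[i] = computeEntry(static_cast<Key>(i));
    }
  }

  // Analogue of the new lookup: one indexed load, error only on the cold path.
  const Kernel& lookup(Key k) const {
    const auto& entry = dispatchTable[static_cast<size_t>(k)];
    if (!entry) throw std::runtime_error("no kernel for key");
    return *entry;
  }
};

int main() {
  ToyOperatorEntry op;
  op.registered[static_cast<size_t>(Key::CPU)] = "cpu_kernel";
  op.backendFallbacks[static_cast<size_t>(Key::Autograd)] = "autograd_fallback";
  op.catchAll = "catchall_kernel";
  op.updateFull();

  std::cout << op.lookup(Key::CPU) << "\n"       // cpu_kernel
            << op.lookup(Key::Autograd) << "\n"  // autograd_fallback
            << op.lookup(Key::CUDA) << "\n";     // catchall_kernel
}
```

In the real dispatcher the backend fallbacks live on Dispatcher rather than on the per-operator entry, which is exactly why several OperatorEntry methods now take a Dispatcher argument when they recompute table entries.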
Things that I didn't do but would be worth doing at some point:

- I really wanted to get rid of the C10_UNLIKELY branch for whether or not the KernelFunction is valid, but it looks like I cannot easily do this while maintaining good error messages. In principle, I could always populate a KernelFunction which errors, but the KernelFunction would need to know which dispatch key is missing (and this is not passed in from the calling convention). It might be possible to do something with functors, but I didn't do it here.
- If we are going to get serious about catch-alls for subsets of operators, we will need to design a new API for them. This diff is agnostic to that question; we don't change the public API at all.
- Precomputation opens up the possibility of subsuming DispatchStub by querying CPU capability when filling in the dispatch table. This is not implemented yet. (There is also a mild blocker here, which is that DispatchStub is also used to share TensorIterator configuration, and this cannot be directly supported by the regular Dispatcher.)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D22236352

Pulled By: ezyang

fbshipit-source-id: d6d90f267078451816b1899afc3f79737b4e128c