update tracing codegen to use redispatch API (#52009)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52009
Taking advantage of the new `redispatch` API to clean up the codegen'd tracing kernels. Instead of directly interacting with the Dispatcher, the tracing kernels now just call the `redispatch` API directly.
One small benefit to this: hopefully the compiler is more likely to inline `Dispatcher::redispatch()`, since it's now used in fewer call-sites. After this change, the only places it's used are:
- the `redispatch` API (`RedispatchFunctions.cpp`)
- BackendSelect kernels.
One small complication: the redispatch API doesn't interact too well with `manual_cpp_binding` ops currently. I put a note with some thoughts in the comments.
Example tracing kernel before:
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const
torch::jit::Node* node = nullptr;
std::shared_ptr<jit::tracer::TracingState> tracer_state;
if (jit::tracer::isTracing()) {
tracer_state = jit::tracer::getTracingState();
at::Symbol op_name;
op_name = jit::Symbol::fromQualString("aten::add");
node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
jit::tracer::recordSourceLocation(node);
jit::tracer::addInputs(node, "self", self);
jit::tracer::addInputs(node, "other", other);
jit::tracer::addInputs(node, "alpha", alpha);
tracer_state->graph->insertNode(node);
jit::tracer::setTracingState(nullptr);
}
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("aten::add", "Tensor")
.typed<Tensor (const Tensor &, const Tensor &, Scalar)>();
auto result =c10::Dispatcher::singleton()
.redispatch<Tensor, const Tensor &, const Tensor &, Scalar>(op,
if (tracer_state) {
jit::tracer::setTracingState(std::move(tracer_state));
jit::tracer::addOutput(node, result);
}
return result;
}
```
after: (note the lack of `Dispatcher::` calls)
```
Tensor add_Tensor(c10::DispatchKeySet ks, const Tensor & self, const Tensor & other, Scalar alpha)
torch::jit::Node* node = nullptr;
std::shared_ptr<jit::tracer::TracingState> tracer_state;
if (jit::tracer::isTracing()) {
tracer_state = jit::tracer::getTracingState();
at::Symbol op_name;
op_name = jit::Symbol::fromQualString("aten::add");
node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
jit::tracer::recordSourceLocation(node);
jit::tracer::addInputs(node, "self", self);
jit::tracer::addInputs(node, "other", other);
jit::tracer::addInputs(node, "alpha", alpha);
tracer_state->graph->insertNode(node);
jit::tracer::setTracingState(nullptr);
}
auto result =at::redispatch::add(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::D
if (tracer_state) {
jit::tracer::setTracingState(std::move(tracer_state));
jit::tracer::addOutput(node, result);
}
return result;
}
```
Test Plan: Imported from OSS
Reviewed By: ezyang
Differential Revision: D26356078
Pulled By: bdhirsh
fbshipit-source-id: bc96ca4c6d90903f1e265859160d4b13a8cc7310