pytorch
6f84c5f0 - [SR] Generalize VarStackNodeWrapper (#71573)

Commit
3 years ago
[SR] Generalize VarStackNodeWrapper (#71573) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71573 Many ops (`gather_ranges_to_dense`, `sigrid_transforms`, etc) are implemented like this: ``` void op_out_(std::vector<Tensor>& output) { // actual op implementation } std::vector<Tensor> op() { std::vector<Tensor> outputs; // populate outputs with empty tensors op_out_(outputs) return outputs; } ``` This pattern is not ideal for ops that are fused with `ListUnpack` - it would be better if we wrote to the outputs directly. This diff extends the ideas from `VarStackNodeWrapper` to allow for this. The changes are: * `s/VarStackNodeWrapper/ProcessedNodeInputWrapper`. The old name was bad because the class is more general than the `VarStack` use case. Also moved the class to `processed_node_wrapper.h` * Add a `ProcessedNodeOutputWrapper`; it's essentially the same idea as `ProcessedNodeInputWrapper`, but it allows non-const access to the underlying tensors. * These classes are very similar, so CRTP is used to facilitate code re-use ghstack-source-id: 148825800 Test Plan: `buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest -- Stack` Reviewed By: swolchok Differential Revision: D33687965 fbshipit-source-id: 5fa0107211116867bb2b63968c126550d32fbea6 (cherry picked from commit 75c263d960a876e4db84c129e6cff2a770a3cd29)
Author
Mike Iovine
Committer
Parents
Loading