[C2] Fix slowness of the ReshapeOp. (#33729)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33729
ReshapeOp is doing some useless movements of data between CPU and GPU, which results in crazy amount of kernel calls from this operator. Which makes this operator ridiculosly slow compared to BatchMatMul for cases of pretty cheap models (for example on some versions of GAT).
This diff is moving ReshapeOp to leverage CPU storage and reduce amount of kernel calls from num_dims + 3 calls for case of 3-D
tensor to 2 calls.
Test Plan:
Unit-tests are still passing.
TODO: perf testing
Reviewed By: akyrola
Differential Revision: D19659491
fbshipit-source-id: 2341b21e57208b988169f2df5fb598be3dc8acb2