pytorch
50054b1a - [AOTInductor] ProxyExecutor support ReinterpretView inputs (#110451)

[AOTInductor] ProxyExecutor support ReinterpretView inputs (#110451)

Summary:
See wrapper.codegen_reinterpret_view(): it returns a temporary handle for a tensor, which has the following problem.

```
# NB, the return handle here represents a temporary tensor, which will be automatically
# released.
# Here's a sample usage in the cpp wrapper code:
# ```
# aoti_torch_addmm_out(
#     buf1,
#     arg1_1,
#     RAIIAtenTensorHandle(tmp_tensor_handle_0),
#     buf0,
#     1L,
#     1L));
# ```
# RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out.
# This could be problematic when it's used in a different pattern, for example:
# ````
# AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6};
# aoti_torch_proxy_executor_call_function(..., tensor_args);
# ````
# RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter
# kernel call.
return f"RAIIAtenTensorHandle({tmp_name})"
```

As a result, ProxyExecutor would generate the following code, which causes an invalid memory access.

Before:
```
// Source Nodes: [fn_with_tuple_output], Original ATen: [fb.fn_with_tuple_output]
AtenTensorHandle tmp_tensor_handle_2;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__reinterpret_tensor(buf3, 2, int_array_0, int_array_1, 0L, &tmp_tensor_handle_2));
...
AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6};
int64_t int_args[] = {1};
aoti_torch_proxy_executor_call_function(proxy_executor, 1, 1, int_args, 3, tensor_args);
buf3.reset();
```

With the fix in this diff, ProxyExecutor generates the following code.

After:
```
// Source Nodes: [fn_with_tuple_output], Original ATen: [fb.fn_with_tuple_output]
AtenTensorHandle tmp_tensor_handle_2;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__reinterpret_tensor(buf3, 2, int_array_0, int_array_1, 0L, &tmp_tensor_handle_2));
...
aoti_torch_proxy_executor_call_function(proxy_executor, 1, 1, std::vector<int64_t>{1}.data(), 3, std::vector<AtenTensorHandle>{RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6}.data());
buf3.reset();
```

I am not exactly a big fan of `std::vector{...}.data()` for creating a temporary array, but I can't think of another fix.

Test Plan: buck2 run mode/dev-nosan deeplearning/aot_inductor/test:test_custom_ops

Reviewed By: desertfire

Differential Revision: D49758764

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110451
Approved by: https://github.com/desertfire