pytorch
3f1c8094 - [static runtime] port c2 argmin kernel (#63632)

Commit
4 years ago
[static runtime] port c2 argmin kernel (#63632)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63632

Local benchmarking with 1 input repeated for 10k iterations on the 290331537_4 local net. Reduces argmin runtime by about 80% and local net execution by about 0.71-0.77 ms.

Before:
```
I0826 17:25:53.972786 1104614 PyTorchPredictorBenchLib.cpp:313] PyTorch run finished. Milliseconds per iter: 7.37599. Iters per second: 135.57
```
```
Static runtime ms per iter: 8.22086. Iters per second: 121.642
Time per node type:
        4.13527 ms.    50.9157%. fb::sigrid_transforms_torch_bind (1 nodes, out variant)
        0.868506 ms.    10.6935%. aten::argmin (1 nodes, out variant)
...
```

After:
```
I0826 17:17:54.165174 1064079 PyTorchPredictorBenchLib.cpp:313] PyTorch run finished. Milliseconds per iter: 6.66724. Iters per second: 149.987
```
```
Static runtime ms per iter: 7.68172. Iters per second: 130.179
Time per node type:
        4.1452 ms.    54.0612%. fb::sigrid_transforms_torch_bind (1 nodes, out variant)
        0.656778 ms.    8.56562%. fb::quantized_linear (8 nodes)
        0.488229 ms.    6.36741%. static_runtime::to_copy (827 nodes, out variant)
        0.372678 ms.    4.86042%. aten::argmin (1 nodes, out variant)
...
Time per node type:
        3.39387 ms.    53.5467%. fb::sigrid_transforms_torch_bind (1 nodes, out variant)
        0.636216 ms.    10.0379%. fb::quantized_linear (8 nodes, out variant)
        0.410535 ms.    6.47721%. fb::clip_ranges_to_gather_to_offsets (304 nodes, out variant)
        0.212721 ms.    3.3562%. fb::clip_ranges_gather_sigrid_hash_precompute_v3 (157 nodes, out variant)
        0.173736 ms.    2.74111%. aten::matmul (1 nodes, out variant)
        0.150514 ms.    2.37474%. aten::argmin (1 nodes, out variant)
```

P447422384

Test Plan:
Test with local replayer sending traffic to `ansha_perf_test_0819.test`, and compare outputs to jit interpreter.

Start compute tier:
```
RUN_UUID=ansha_perf_test_0819.test.storage JOB_EXPIRE_TIME=864000 MODEL_ID=290331537_4 PREDICTOR_TAG= PREDICTOR_VERSION=405 PREDICTOR_TYPE=CPU ADDITIONAL_FLAGS="--enable_disagg_file_split=true --enable_adx=false --load_remote_file_locally=true --pytorch_predictor_static_runtime_whitelist_by_id=290331537" GFLAGS_CONFIG_PATH=sigrid/predictor/gflags/predictor_gflags_ads_perf_cpu_pyper SMC_TIER_NAME=sigrid.predictor.perf.ansha_per_test_0819.test.storage CLUSTER=tsp_rva ENTITLEMENT_NAME=ads_ranking_infra_test_t6 PREDICTOR_LOCAL_DIRECTORY= ICET_CONFIG_PATH= NNPI_COMPILATION_CONFIG_FILE= NUM_TASKS=1 NNPI_NUM_WORKERS=0 tw job start /data/users/ansha/fbsource/fbcode/tupperware/config/admarket/sigrid/predictor/predictor_perf_canary.tw
```

Start nnpi tier:
```
RUN_UUID=ansha_perf_test_0819.test JOB_EXPIRE_TIME=247200 MODEL_ID=290331537_4 PREDICTOR_TAG= PREDICTOR_VERSION=343 PREDICTOR_TYPE=NNPI_TWSHARED ADDITIONAL_FLAGS="--torch_glow_min_fusion_group_size=30 --pytorch_storage_tier_replayer_sr_connection_options=overall_timeout:1000000,processing_timeout:1000000 --predictor_storage_smc_tier=sigrid.predictor.perf.ansha_perf_test_0819.test.storage --pytorch_predictor_static_runtime_whitelist_by_id=290331537" GFLAGS_CONFIG_PATH=sigrid/predictor/gflags/predictor_gflags_ads_perf_glow_nnpi_pyper_v1 SMC_TIER_NAME=sigrid.predictor.perf.ansha_perf_test_0819.test CLUSTER=tsp_rva ENTITLEMENT_NAME=ads_ranking_infra_test_t17 PREDICTOR_LOCAL_DIRECTORY= ICET_CONFIG_PATH= NNPI_COMPILATION_CONFIG_FILE= NUM_TASKS=1 NNPI_NUM_WORKERS=0 tw job start /data/users/ansha/fbsource/fbcode/tupperware/config/admarket/sigrid/predictor/predictor_perf_canary.tw
```

```
buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest -- StaticRuntime.IndividualOps_Argmin --print-passing-details
```

Compared outputs to jit interpreter to check for no differences greater than 1e-3 (with nnc on)
https://www.internalfb.com/intern/diff/view-version/136824794/

Reviewed By: hlu1

Differential Revision: D30445635

fbshipit-source-id: 048de8867ac72f764132295d1ebfa843cde2fa27
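
The diff itself is not reproduced in this message, so the following is only a minimal sketch of the kind of specialized kernel the title refers to: an argmin over one axis of a contiguous float buffer, viewed as `[outer, n, inner]`. The function name `argmin_along_dim`, its signature, and the tie and NaN behavior are assumptions made for illustration; they are not taken from the ported C2 kernel or from the static runtime code.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the diff): argmin over the middle axis of a
// contiguous row-major buffer interpreted as [outer, n, inner].
// Assumes n >= 1 and x.size() == outer * n * inner.
// Ties resolve to the first (lowest) index; NaN inputs are not treated
// specially, which may differ from aten::argmin.
std::vector<int64_t> argmin_along_dim(const std::vector<float>& x,
                                      int64_t outer, int64_t n, int64_t inner) {
  std::vector<int64_t> out(static_cast<std::size_t>(outer * inner), 0);
  for (int64_t o = 0; o < outer; ++o) {
    // Start of the o-th [n, inner] slab.
    const float* base = x.data() + o * n * inner;
    for (int64_t i = 0; i < inner; ++i) {
      int64_t best = 0;
      float best_val = base[i];
      // Walk the reduced axis with stride `inner`.
      for (int64_t k = 1; k < n; ++k) {
        const float v = base[k * inner + i];
        if (v < best_val) {
          best_val = v;
          best = k;
        }
      }
      out[static_cast<std::size_t>(o * inner + i)] = best;
    }
  }
  return out;
}
```

For example, calling `argmin_along_dim(x, 2, 3, 1)` on `x = {3, 1, 2, 0, 5, -1}` reduces each row of a 2x3 matrix and yields the indices `{1, 2}`. Collapsing the tensor shape into three extents keeps the inner loops free of per-element index arithmetic, which is one plausible source of the speedup reported in the summary, though the message does not say where the gain comes from.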