pytorch
f556d735 - [torch] Implement aten::native_batch_norm.out for CPU (#88604)

Commit
2 years ago
[torch] Implement aten::native_batch_norm.out for CPU (#88604) Summary: Implement `native_batch_norm.out` for CPU. Reuses the main logic for `native_batch_norm` but extract out the Tensor creation logic for outputs. There are 3 outputs: `output`, `save_mean` and `save_var`. `batch_norm_cpu` calls `batch_norm_cpu_update_stats_template` to get `save_mean` and `save_var`, and then calls into `batch_norm_cpu_transform_input_template` which initializes `output`. In the implementation of `batch_norm_cpu_out`, I did the following: * Let `batch_norm_cpu_transform_input_template` to take another argument `output`, ask the call sites to pass in a output Tensor. * Overload `batch_norm_cpu_update_stats_template` to take `save_mean` and `save_var`, ask the call sites to pass in those Tensors. * In `batch_norm_cpu_out`, pass `output`, `save_mean` and `save_var` all the way to our new `batch_norm_cpu_transform_input_template` and `batch_norm_cpu_update_stats_template`. * In `batch_norm_cpu`, prepare for these outputs and call `batch_norm_cpu_out`. Test Plan: Enable unit tests for `native_batch_norm.out`. Differential Revision: D40992036 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88604 Approved by: https://github.com/iseeyuan, https://github.com/jjsjann123
Author
Committer
Parents
Loading