pytorch
ac6b5bea - [torch][segment_reduce] Add support for mean reduction (cpu) (#59521)

Commit
3 years ago
[torch][segment_reduce] Add support for mean reduction (cpu) (#59521) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59521 This diff is adding support for mean reduction for CPU (fwd + bckwd). Will add cuda implementation in subsequent PR. We are using "cub::DeviceSegmentedReduce" for other aggregation, trying to see how to support mean or will write custom kernel for it. Next Steps: - cuda support for mean - 2d data input support - more testing - benchmarking Test Plan: updated unit test. Still relying on manual data for ease of debugging. Will add more tests that covers edge cases once major features are complete. Reviewed By: ngimel Differential Revision: D28922547 fbshipit-source-id: 2fad53bbad2cce714808ff95759cbdbd45bb4ce6
Author
Serhat Yilmaz
Parents
Loading