pytorch
4a1633ca - [Inductor] GEMM Shape Padding Optimization (#90425)

[Inductor] GEMM Shape Padding Optimization (#90425)

Summary: Optimize the shape padding along the following lines:
- Add BFloat16 support for AMP training and Float16 support for inference.
- Optimize the microbenchmark to avoid peak-memory issues, and profile the memory ops as well so the padding decision is more accurate.
- Add a flag to turn padding of dims N and M in `torch.bmm` on or off, since the expensive memory copy from `.contiguous` can cause peak-memory issues in internal models.

Test Plan: CI

Differential Revision: D41724868

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90425
Approved by: https://github.com/jianyuh
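For readers new to the technique, the sketch below illustrates the basic idea behind GEMM shape padding: zero-pad an awkwardly sized dimension up to a hardware-friendly multiple so the matmul kernel can hit faster tile configurations, at the cost of extra memory traffic for the copy. This is a minimal illustration of the concept, not Inductor's implementation; the helper names (`pad_to_multiple`, `padded_mm`) and the alignment multiple of 8 are assumptions made for this example.

```python
import torch
import torch.nn.functional as F

# Illustrative alignment target; the real heuristics pick alignment based on
# dtype and backend, not a hard-coded constant.
ALIGN = 8

def pad_to_multiple(t: torch.Tensor, dim: int, multiple: int = ALIGN) -> torch.Tensor:
    """Zero-pad `t` along `dim` up to the next multiple of `multiple`."""
    size = t.size(dim)
    pad_len = (multiple - size % multiple) % multiple
    if pad_len == 0:
        return t
    # F.pad expects (left, right) pairs starting from the last dimension.
    pad = [0, 0] * (t.dim() - 1 - dim) + [0, pad_len]
    return F.pad(t, pad)

def padded_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute a @ b after zero-padding the shared K dimension.

    Zero padding along K leaves the product unchanged, so the result needs
    no slicing; only the kernel's internal shapes become better aligned.
    """
    a_p = pad_to_multiple(a, dim=1)   # pad K on the left operand (M x K)
    b_p = pad_to_multiple(b, dim=0)   # pad K on the right operand (K x N)
    return a_p @ b_p

a = torch.randn(128, 125)
b = torch.randn(125, 128)
out = padded_mm(a, b)
# The padded columns/rows contribute only zeros to the reduction.
assert torch.allclose(out, a @ b, atol=1e-5)
```

Padding the shared K dimension with zeros leaves the product unchanged, so the output needs no fix-up. Padding M or N, by contrast, changes the output shape and requires a `.contiguous`-style copy plus slicing afterwards, which is the extra memory cost the new `torch.bmm` flag lets users opt out of.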