Add option to use ninja to compile ahead-of-time cpp_extensions
Background
^^^^^^^^^^
Previously, ninja was used to compile+link inline cpp_extensions and
ahead-of-time cpp_extensions were compiled with distutils. This PR adds
the ability to compile (but not link) ahead-of-time cpp_extensions with ninja.
The main motivation for this is to speed up cpp_extension builds: distutils
does not make use of parallelism. With this PR, using the new option, on my machine,
- torchvision compilation goes from 3m43s to 49s
- nestedtensor compilation goes from 2m0s to 28s.
User-facing changes
^^^^^^^^^^^^^^^^^^^
I added a `use_ninja` flag to BuildExtension. This defaults to
`False`. When `use_ninja` is True:
- it will attempt to use ninja.
- If we're on windows, then this throws a warning and falls back to distutils.
(this PR doesn't add windows support for this functionality, but it can
be a future task).
- If we can't find ninja, then we hard error.
I wasn't sure whether or not to set `use_ninja=True` as the default. If
we make that the default, then it instead needs to do a "best effort"
and have an automatic fallback path to avoid being bc-breaking. Testing
such a change is kind of difficult. I'm happy to figure out the
design/behavior of this in this PR or in a follow-up.
Implementation Details
^^^^^^^^^^^^^^^^^^^^^^
This PR makes this change in two steps. Please me know if it would be
easier to review this if I split this up into a stacked diff.
Those changes are:
1) refactor _write_ninja_file to separate the policy (what compiler flags
to pass) from the mechanism (how to write the ninja file and do compilation).
2) call _write_ninja_file and _run_ninja_build while building
ahead-of-time cpp_extensions. These are only used to compile objects;
distutils still handles the linking.
Change 1: refactor _write_ninja_file to seperate policy from mechanism
- I split _write_ninja_file into: _write_ninja_file and
_write_ninja_file_to_build_library
- I renamed _build_extension_module to _run_ninja_build
Change 2: Call _write_ninja_file while building ahead-of-time
cpp_extensions
- _write_ninja_file_and_compile_objects calls _write_ninja_file to only
build object files.
- We monkey-patch distutils.CCompiler.compile to call
_write_ninja_files_and_compile_objects
- distutils still handles the linking step. The linking step is not a
bottleneck so it was not a concern.
- This change only works on unix-based systems. Our code for windows
goes down a different codepath and I did not want to mess with that.
- If a system does not support ninja, we raise a warning and fall back
to the original compilation path.
Test Plan
^^^^^^^^^
Adhoc testing
- I built torchvision using pytorch master and printed out the build
commands. Next, I used this branch to build torchvision and looked at
the ninja file. I compared the ninja file with the build commands and
asserted that they were functionally the same.
- I repeated the above for pytorch/nestedtensor.
PyTorch test suite
- I added cpp extension tests to test BuildExtension with
`use_ninja=False` and `use_ninja=True`.
- This involves having a duplicate of the cpp extension files. One of
the copies is used to build with `use_ninja=False`, the other copy is
used to build `use_ninja=True`. I am not sure if it is possible to
deduplicate these.
- Ran the cpp extension tests. `python run_test.py -i cpp_extension -v`
Help wanted and/or future work
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- API: should we have `use_ninja=True` and implement an automatic
fallback?
- Testing: How do we clean up the tests?
[ghstack-poisoned]