PR #28157: Warn on excessive captured constants
Imported from GitHub PR https://github.com/jax-ml/jax/pull/28157
One of the most common modes by which users may have extended lowering times is by unintentionally capturing, rather than tracing, a large number of constants, e.g. capturing the model weights as part of the computation.
To address this we introduce two new flags
- `jax_captured_constants_warn_bytes` defaults to `2 * 10 ** 9` (2GB). The number of total bytes of captured constants before warning is issued. (Note that the maximum size binary XLA can serialize for the compilation cache is 2GB.)
- `jax_captured_constants_report_frames` defaults to `0`. If a warning is issued, how many stack frames to report for each constant. Defaults to 0, which means we don't generate the report by default. Reports all frames if set to `-1`.
Both message are returned by using `warnings.warn`. The envisioned workflow for debugging captured constants is as follows:
1. The user is alerted to the problem by the initial (by default). This looks something like:
```
UserWarning: A large amount of constants were captured during lowering (125.00GB total).
If this is intentional, disable this warning by setting JAX_CAPTURED_CONSTANTS_WARN_BYTES=-1.
To obtain a report of where these constants were encountered, set JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=-1.
```
2. The user sets the `JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=-1` to obtain debugging information to locate the source of the constants. Setting `JAX_CAPTURED_CONSTANTS_REPORT_FRAMES` to a small positive integer will return a suffix of the number of captured frames. Upon rerunning the code, the report generated will be something like this:
```
UserWarning: A large amount of constants were captured during lowering (125.00GB total).
If this is intentional, disable this warning by setting JAX_CAPTURED_CONSTANTS_WARN_BYTES=-1.
The subsequent report may be disabled by setting JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=0.
Largest 5 allocation(s):
Constant <class 'numpy.ndarray'>, float32[1439,721,720], 5.00GB captured at:
/home/user/project/main.py:193 (<module>)
/home/user/project/main.py:156 (run_export)
/home/user/project/main.py:147 (run_forward_exported)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/pytree_utils.py:98 (tree_map_over_nonscalars)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/pytree_utils.py:97 (g)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/spherical_harmonic.py:661 (g)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/spherical_harmonic.py:266 (inverse_transform)
Constant <class 'numpy.ndarray'>, float32[1439,721,720], 5.00GB captured at:
/home/user/project/main.py:193 (<module>)
/home/user/project/main.py:156 (run_export)
/home/user/project/main.py:147 (run_forward_exported)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/pytree_utils.py:98 (tree_map_over_nonscalars)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/pytree_utils.py:97 (g)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/spherical_harmonic.py:661 (g)
/home/user/project/.venv/lib/python3.10/site-packages/dinosaur/spherical_harmonic.py:266 (inverse_transform)
```
and so fourth. The report is hard coded to only report the the top 5 (ordered by `nbytes`) largest constants.
3. The user can then set a breakpoint at last line listed, and inspect the operands to find the problematic constant by the shape and type we report.
The reason for the two stage process is it is very easy and cheap to check the number of bytes each time we lower a function. However generating the report involves a traversal of the Jaxpr, and so is not.
For Googlers [b/403532544](http://b/403532544#comment26)
Merging this change closes #28157
PiperOrigin-RevId: 755142222