jax PR #240: CI: 02/25/25 upstream sync
charleshofer merged 276 commits into rocm-main from ci-upstream-sync-127_1

Commits:
Update XLA dependency to use revision
0409d1c3
[better_errors] Continue adding debug info to Jaxprs (step 8)
faa0ad6f
Merge pull request #26455 from gnecula:debug_info_jaxpr_8
1e2a5770
[Pallas/Mosaic GPU] Enable progressive lowering for integer addition.
c7199fe8
Fix the string array test to not assume that there will always be exa…
6662ea96
Merge pull request #26447 from jakevdp:refactor-contractions
e14466a8
Enable pivoted QR on GPU via MAGMA.
b1b56ea0
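For context, a minimal sketch of pivoted QR from the JAX side, assuming the `pivoting=True` flag on `jax.lax.linalg.qr` that #25955 wires up to MAGMA on GPU:

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
# With pivoting=True, the column permutation is returned along with Q and R.
q, r, piv = lax_linalg.qr(a, full_matrices=False, pivoting=True)
# Column-pivoted QR factors A with permuted columns: A[:, piv] ~= Q @ R.
print(jnp.allclose(a[:, piv], q @ r, atol=1e-5))
```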
Make sure we take libTPU version into account in the Pallas lowering
f1ab7514
[sharding_in_types] When caching mesh with axis_types, make sure the …
b4b4a98d
[Mosaic GPU] Remove old jaxlib version guards.
837418c6
refactor: move jnp.einsum impl into its own submodule
7ab7b214
Download and use `jax` wheels from GCS bucket for nightly/release tes…
93831bdd
[sharding_in_types] Make the typing checks and sharding rule checks a…
2d01df76
Merge pull request #25955 from tttc3:magma_qr
f7e2901e
Merge pull request #24910 from olupton:expect-pgle
5b697728
jax.numpy reductions: avoid upcast of f16 when dtype is specified by …
b5e7b60d
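To illustrate the behavior this commit touches (a sketch; the internal accumulation strategy is backend-dependent): when the user passes `dtype` explicitly, `jax.numpy` reductions honor it rather than upcasting f16 inputs.

```python
import jax.numpy as jnp

x = jnp.ones((8,), dtype=jnp.float16)
print(jnp.mean(x).dtype)                     # float16
print(jnp.mean(x, dtype=jnp.float32).dtype)  # float32, as requested
```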
Part 1 of a new autodidax based on "stackless"
9145366f
Merge pull request #26403 from jakevdp:bf16-mean
4f1c67e6
Enable shardy batch partitionable FFI test.
9298018a
Fix doc string for PmapSharding
e231a35a
[xla:cpu] Implement XLA FFI handlers for CPU Jax callbacks.
8c685be6
[sharding_in_types] Fix some properties that assumed axis_types alway…
d58c3a47
Rename `sharding` argument to `out_sharding` for `lax.reshape`, `lax.…
1a62df1a
Create `_BaseMesh` so that properties can be shared between Mesh and …
0944e520
Merge pull request #26503 from garymm:patch-1
73c626d9
Merge pull request #26373 from jax-ml:autodidax-stackless
153a7cf9
[Mosaic TPU] Support bf16 div if HW does not directly support.
876668fa
[sharding_in_types] Error out when PartitionSpec is passed to APIs th…
15cd83ae
[sharding_in_types] Make `sharding` arg to ShapedArray kwarg only
3ec7a67e
Added stream annotation support via @compute_on('gpu_stream:#')
60f01846
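A heavily hedged sketch of the stream annotation from #25056: `jax.experimental.compute_on` already accepts `'device_host'`, and per this change it also accepts `'gpu_stream:<n>'` strings. Availability depends on a GPU build with stream support; the block below is illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

@jax.jit
def f(x):
    with compute_on('gpu_stream:1'):  # ask XLA to place this block on stream 1
        y = x * 2.0
    return y + 1.0

print(f(jnp.arange(4.0)))
```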
Update XLA dependency to use revision
0941ec9f
[Mosaic GPU] Reorganize tests to make sure WGMMA tests are skipped on…
cbe102df
Update oryx for JAX debug_info
7161cad6
[pallas:mgpu] Fix and test multiple indexers where one is a dynamic s…
305e55f3
Fix Windows build for Mosaic GPU extension
a493df4d
[Mosaic GPU] Put all inline asm output constraints before input const…
f3f54dee
Merge pull request #25056 from chaserileyroberts:chase/compute_on_stream
af2fa9bc
[Mosaic GPU] Factor out arch specific Pallas Mosaic GPU tests
54fa1b9a
Migrated to mypy 1.14.1 with --allow_redefinition
194884d3
Skip pipeline mode args in tests with older libTPU
b0b1fa7d
Fix the error message to say `out_sharding` instead of `sharding` in …
2062e986
[Mosaic GPU] Simplify the collective barrier test to avoid GMEM atomics
5c7caa31
Merge pull request #26486 from superbobry:maint-2
5889fd0d
[Mosaic GPU] Add support for Blackwell MMA with n=512
845e3f8f
Rewrite generic LU pivots to permutation implementation using vmap in…
7efbda62
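The helper being rewritten here converts LAPACK-style pivot indices into an explicit row permutation; a small usage sketch:

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.array([[0.0, 1.0], [1.0, 0.0]])
lu, pivots, perm = lax_linalg.lu(a)
# pivots records the row swap at each elimination step; perm is the
# equivalent permutation, which this helper computes from pivots.
print(lax_linalg.lu_pivots_to_permutation(pivots, 2))  # matches perm
```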
Only cache jax.Array._npy_value when a copy is required.
7f999298
Fix some busted batching rules in lax.linalg.
ea4e324f
Split NamedSharding into a separate file called named_sharding.py so …
229aa65a
Merge pull request #26472 from jakevdp:jnp-einsum
5ebb7eb5
refactor: move lax_numpy indexing routines to their own submodule
f750d0b8
[Mosaic GPU] Factor out Mosaic GPU dialect arch-specific tests
ae9389dc
[Mosaic GPU] Use gettatr to import version-specific dialect ops
ea5eb49a
Reorder top-level functions in lax.linalg, and add/expand docstrings.
c6c38fb8
Removed unused ``# type: ignore`` comments
a73456d5
Merge pull request #26461 from ROCm:run-less-rocm-tests
91c6e449
Move info!=0 logic into lax.linalg.tridiagonal lowering rule.
14afb732
Merge pull request #26509 from andportnoy:aportnoy/pallas-mosaic-gpu-…
f0cd1686
Merge pull request #26518 from superbobry:maint-2
60dcded2
[better_errors] Make it explicit that debug_info is not None.
a0812cd5
[Mosaic GPU] Define TMEMLayout without referring to the PTX guide
4a8023fe
Update XLA dependency to use revision
80dcb7b5
[pallas:triton] Added basic support for `lax.concatenate`
3162cc4d
[pallas:mgpu] Change FA3 kernel bc lax.div doesn't like mixed types a…
49ad2415
Make sure that tests don't change the state of the compilation cache
5ab8c5a8
Merge pull request #26522 from andportnoy:aportnoy/mosaic-gpu-test-sm90a
12d533f6
Merge pull request #26523 from andportnoy:aportnoy/mosaic-gpu-dialect…
4df59616
Remove an unused import
cdcf35fd
Merge pull request #26498 from jakevdp:jnp-indexing
794ae0f7
Ignore ImportError for Triton on Windows
b287c392
Fix segfault when old GPU plugins are installed.
902ebe1b
Add JAX error checking support
6addf02a
Fix breakage in indexing refactor
b93934c7
jax.lax: improve docs for pow & related functions
531443c4
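For reference, a sketch of `lax.pow` and the related `lax.integer_pow` (among the functions this docs pass plausibly covers):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])
print(lax.pow(x, 2.0))        # elementwise x ** y with a float exponent
print(lax.integer_pow(x, 3))  # x ** k for a static integer k
```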
Merge pull request #26542 from jakevdp:fix-breakage
ca87f5f3
Merge pull request #26528 from jakevdp:lax-docs
4b94665f
Don't write atime file if JAX_COMPILATIION_CACHE_MAX_SIZE == -1
d5d43fc4
Fix the type annotations and don't += a generator (it's confusing)
36d7f853
refactor: import numpy objects directly in jax.numpy
33b989ac
Adds support for string and binary data processing in Colocated Python.
9b6b569f
jax.lax: improve docs for bitwise operators.
29771dd0
Merge pull request #26529 from skye:atime
531d80dc
Support optimization_level and memory_fitting_level XLA compilation o…
d3850e7f
[shard_map] fix debug_print with partial auto shmap
36819604
Merge pull request #25887 from mattjj:partial-auto-something
df135d2f
[JAX] Allow pallas to accept scalar shape semaphores.
9a8c9a56
Update XLA dependency to use revision
a6c7cef9
Minor format cleanup
b6361b3e
[Pallas] Reductions with replicated axes.
eaceac3b
[Mosaic] Several fixes/improvements for the new TPU interpret mode.
962eb419
Update XLA dependency to use revision
936eb844
[TPU][Mosaic][Easy] Add verification for AssumeMultipleOp.
a6fcb741
Cache `shape` property on `Mesh`.
7bfa420d
[Mosaic GPU] Implement lowerings for `Tile` and `Transpose` transform…
52f8fbee
Update XLA dependency to use revision
bffeb5d6
Integrate LLVM at llvm/llvm-project@912b154f3a3f
e78a469b
Fix test failure in PGLE tests.
b3ed528f
Fix typos
180be997
Change type signature of lexsort in stub file to match type signature…
422b747d
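`jnp.lexsort`, whose stub signature this aligns, sorts by the last key in the sequence first, matching `numpy.lexsort`:

```python
import jax.numpy as jnp

primary = jnp.array([3, 1, 3])       # last key: sorted first
secondary = jnp.array([30, 20, 10])  # earlier key: breaks ties
print(jnp.lexsort((secondary, primary)))  # [1 2 0]
```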
Merge pull request #26570 from hawkinsp:pgle
7e27713a
Update XLA dependency to use revision
490ceba7
Update XLA dependency to use revision
9f8a0bf9
Update XLA dependency to use revision
ce399aa6
Merge pull request #26571 from rajasekharporeddy:typos
8f8e3187
Merge pull request #26550 from jakevdp:lax-docs
63bc22d6
Merge pull request #26574 from BaconBreaker:lexsort-stub-file
82216fc4
Integrate LLVM at llvm/llvm-project@9d24f9437944
725087e1
[pallas] Skip `OpsTest.test_concat_constant` if TPU is not available
d4559ba4
Merge pull request #26401 from jakevdp:numpy-consts
72f0a90e
Error unconditionally for jit, pjit and with_sharding_constraint if `…
1dc58b79
jax.lax: improve docs for comparison operators
7f115fbb
[Mosaic TPU] Support mask concat
bb68124c
[sharding_in_types] Set the `sharding_in_types` config to True. This …
00d82970
Make device_put resharding on single device array input work under us…
8bcbf585
Let users pass in pspecs to with_sharding_constraint when `use_mesh` …
1079dc44
add distributed.is_initialized
5db78e7a
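A sketch of the helper added in #26172 (the coordinator address and process counts below are illustrative):

```python
import jax

if not jax.distributed.is_initialized():
    jax.distributed.initialize(
        coordinator_address="localhost:12345",  # hypothetical single host
        num_processes=1,
        process_id=0,
    )
print(jax.distributed.is_initialized())  # True once initialize() has run
```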
Merge pull request #26172 from ZacCranko:is-distributed-init
09491e2b
Exclude auto axes whenever extending the axis_env via core.extend_axi…
c825241c
Expose `get_ty` aka get_aval from jax namespace
b3508333
Update XLA dependency to use revision
44872296
Use consistent dtype for forward and backwards in jax.nn.dot_product_…
d5e5b42d
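`jax.nn.dot_product_attention`, whose forward/backward dtype consistency this addresses, takes (batch, seq, heads, head_dim) arrays; a minimal sketch:

```python
import jax
import jax.numpy as jnp

B, T, N, H = 2, 8, 4, 16
kq, kk, kv = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(kq, (B, T, N, H), dtype=jnp.float16)
k = jax.random.normal(kk, (B, T, N, H), dtype=jnp.float16)
v = jax.random.normal(kv, (B, T, N, H), dtype=jnp.float16)
out = jax.nn.dot_product_attention(q, k, v)
print(out.shape, out.dtype)  # (2, 8, 4, 16) float16
```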
Now that sharding_in_types config flag is True, remove the config and…
a3edfb43
Use "common" instead of "build" for some flags in .bazelrc
679f1a95
Error out if going from `Manual` -> `Auto/Explicit` AxisTypes in the …
66d04f85
roofline: Support computing flops for binary ops.
7eee2de7
Mark `in_shardings` and `out_shardings` as Any for typing reasons sin…
401fa901
fix scan doc on the `unroll` argument.
ae10f2da
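The `unroll` argument in question: `unroll=n` lowers n scan iterations per loop step (`unroll=True` unrolls fully) without changing results:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    new = carry + x
    return new, new  # (next carry, per-step output)

xs = jnp.arange(8.0)
total, partials = jax.lax.scan(step, 0.0, xs, unroll=4)
print(total)  # 28.0, identical to unroll=1; only the lowering differs
```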
Relax the check in `_mapped_axis_spec` to allow `()` and `None` to be…
1081c1f1
Merge pull request #26615 from froystig:scan-docfix
eef829cb
[Pallas] Support dynamic grids in the new TPU interpret mode
ac74857d
Relax one more check in partial_eval_jaxpr_nounits
dbb46e92
Only add new manual axes to residuals when adding axes with partial_a…
b7c66bd2
[StableHLO] Only emit StableHLO from xla_computation_to_mlir_module
9ad9c3d3
Merge pull request #26591 from jakevdp:lax-docs
cb0d326e
[sharding_in_types] Initial support for partial-auto/explicit shard_m…
8305803b
If cur_mesh is empty and AxisTypes of Mesh passed to shmap are Explic…
b6b319cd
canonicalize closed over values if **atleast** 1 mesh axis is `Manual…
262aab74
[Mosaic] Consider divisibility when doing large tiling
37af0135
Update XLA dependency to use revision
2c8a8ed7
Bump up timeout for building wheel to 60 minutes
040b718e
Added `jax.experimental.multihost_utils.live_devices` API.
ddcb7dee
[Pallas][Mosaic] Support float8_e4m3b11fnuz
b7968474
[JAX] Generate more readable error for failed device deserialization …
71f9764e
Fix wheel downloads for nightly artifact testing
6b39eb6b
CI: skip execution of convolutions.ipynb
380fb0a5
Fix head comment: was referring to nonexistent parameters.
08de0128
[pallas:mosaic_gpu] Added support for binary/comparison ops with WG s…
7438976e
Internal: improved type annotations for lu.WrappedFun
fe00aa0f
Merge pull request #26646 from jakevdp:fix-rtd
bbc4fa71
Use the mesh of `out_aval` when converting GSPMDSharding to NamedShar…
250e2ee7
Set the mesh of `tangent.aval` when we are creating `zeros_like_aval`…
bcd4048d
Internal compatibility change
b510127a
Update XLA dependency to use revision
0ef5eb06
Add dtype dispatching to FFI example.
ca23b749
[better_errors] Fix a debug_info test, and expand the docstring for t…
edf401d7
Merge pull request #26607 from dfm:ffi-example-dtype-dispatch
0e0a04fa
Reverts 9ad9c3d3c2b6b5c3ea736af9b4d8c595537de93c
65526760
Start adding primitive registration helper functions to lax.linalg.
a981e1c4
Update lax.linalg.lu primitive to use registration helper functions.
126909b6
Update internal unop primitive helper to pass kwargs to dtype rule.
09325d92
Merge pull request #26649 from jakevdp:wrapped-fun
237b7941
Merge pull request #26657 from gnecula:debug_info_test
cb8ff825
[CI] Enable workflow_dispatch for the continuous workflow
c664a0cd
Set the mesh of the sharding during broadcast in vmap so that we don'…
66037d10
Disable tests/debug_info_test.py:test_vjp_of_jit test. Currently fail…
87a7158f
Update lax.linalg.qr primitives to use registration helper functions.
ed10003a
Don't set PYTHONWARNINGS=error for tests that use TensorFlow.
673a02d6
Refactor Jax FFI lowering to prepare for implementing CPU/GPU callbac…
2d1bc5c2
Update Windows docker image to pin it to a sha
a0736391
Use new ML Build CUDA images
5089fb01
Update several remaining lax.linalg primitives to use registration he…
c4418c10
Reverts 655267609b5589fda8358ae7aaf2eb832036407a
6c83d436
[JAX] Implement an initial object API for colocated Python
96b7dbab
roofline: Handle ClosedJaxpr instances.
b3fcba7c
Implement Jax CPU/GPU callbacks with XLA's FFI.
6e83de59
Allow casting to the same axis type
629426f8
[sharding_in_types] Make slice and ellipsis work with `.at[...].get(o…
80f18ded
Reverts 6e83de5909b5dad6827c1682726c840c9688b32d
34077851
[sharding_in_types] Allow `auto_axes` and `explicit_axes` to take num…
7c4fe2a7
Update XLA dependency to use revision
a75bea51
Implement SVD algorithm based on QR for CPU targets
e03fe3a0
[sharding_in_types] Add sharding rules for the following primitives:
d695aa4c
[better_errors] Cleanup use of DebugInfo.arg_names and result_paths
1be801ba
Update XLA dependency to use revision
e0e5457c
[pallas:mosaic_gpu] Use `{min,max}imumf` instead of `{min,max}numf`
74b2e020
[mosaic_gpu] Warmup the kernel when doing CUPTI based profiling
908ff49e
fix getting gcc major version
dd4aa79d
[sharding_in_types] Add more reshape sharding support
7d3c63ed
Update XLA dependency to use revision
b0cfcb8e
Merge pull request #26569 from gnecula:debug_info_arg_names
c17ea805
Mark error_check_test as thread-unsafe.
aeff675e
Merge pull request #25053 from JanLuca:gesvd
c74f497e
Prepare for JAX release 0.5.1
07440f4a
Update lax.linalg.svd primitive to use registration helper functions.
ae656e15
Plumb layout through the creation of PjRtArrays.
9d421c91
Fix rank promotion error in JVP of batched eigh.
6bd99207
Update JVP rule for lax.linalg.lu to use vmap instead of broadcasted_…
62530d59
Use `uv` to install Python packages
4b4f2f9c
Fix shard_map debug_nan leakage of manual out_avals in the impl rules…
6d8be966
Make mesh and *_spec parameters optional.
79e1e1fc
Merge pull request #26690 from h-vetinari:gcc
d0cb30b7
Remove extra "--system" argument
cadb5311
Share the logic of allowing propagation to input/output between JAX a…
8739232f
Merge branch 'release/0.5.1' into main
54f70724
Update version numbers after 0.5.1 release.
c8c4cfa0
Merge pull request #26708 from hawkinsp:postrelease
11b45a6f
Change documentation to recommend libtpu from pypi instead of GCS.
85add667
Merge pull request #26713 from hawkinsp:docs4
d2867333
Add sharding mismatch to explain_tracing_cache_miss
6f8bab3c
Add a param to not normalize the attention weights
8f1dd02e
Merge pull request #26715 from Rifur13:normalize
41faf51a
Sharp bits: add note on subnormal flush-to-zero
f5ca46f5
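The note concerns subnormal floats: values smaller in magnitude than the smallest normal number may be flushed to zero on some backends. A backend-dependent illustration:

```python
import jax.numpy as jnp

tiny = jnp.finfo(jnp.float32).tiny  # smallest *normal* float32, ~1.18e-38
sub = jnp.float32(tiny) / 2         # subnormal value
print(sub == 0.0)  # True on backends that flush subnormals to zero
```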
[sharding_in_types] Error out when using `auto_axes` or `explicit_axe…
b707f0bd
Pass in "JAX_" variables to the docker container
e4d63dff
Remove `-y` and extra `--pre` arguments
cc830748
[Pallas:MGPU] Consistently use i32 as the grid index type in emit_pip…
80848ad8
[pallas:mosaic_gpu] Added WG lowering rules for TMA primitives and ru…
7eadc64b
Update XLA dependency to use revision
43f2da86
[Pallas:MGPU] Avoid SMEM->GMEM wait if no outputs are transferred in …
d0d5bba6
[Pallas:MGPU] Change WG semantics convention to represent scalar arra…
71c76220
Skip tests using StringDType when NumPy version is below 2.0
7fb6788d
[Pallas:MGPU] Enable lowering for .astype and scalar broadcasts
676aceba
[pallas:mosaic_gpu] Use `emit_pipeline` for pipelining in the lowering
c13a2f95
[Mosaic GPU] Add dialect lowering logic for splat constants.
5b13883f
[better_errors] Port the Pallas debug info mechanisms to the new JAX …
c4e0db6f
[Mosaic GPU][NFC] Remove unnecessary lambda wrappers from test.
6f966397
[Mosaic GPU] Add layout inference for `arith.Ext{F,SI,UI}Op` and `ari…
5312b5e3
Update tsan suppressions.
2325cf35
Merge pull request #26058 from gnecula:debug_info_pallas
7acd60c8
[Mosaic GPU] Add layout inference for `scf.ForOp` and `scf.YieldOp`.
5024ef21
[sharding_in_types] `physical_aval` should set the correct sharding o…
9deb7e3d
[jax:custom_partitioning] Propagate static arguments to sharding_rule…
30348e90
Add tests for lax.linalg.svd algorithm specification.
a3a48af1
[Pallas:MGPU] Adjust warpgroup lowering to the recent emit_pipeline c…
3d87a01b
[Easy][Mosaic] Tiny refactor for clarity in getTypeBitwidth
083ffd37
Add jax.experimental._mini_mpmd
2b4c455a
Typecheck pallas.CostEstimate
0f8e6b99
Merge pull request #26564 from gspschmid:gschmid/mini_mpmd
a6b8384a
Merge pull request #26734 from hawkinsp:tsan
69a6aaa3
Merge pull request #26668 from jakevdp:sharp-bits
c06c7851
Fix incorrect line separator
7a162f2a
Create `jax` wheel build target.
eb912ad0
[Mosaic GPU] Use explicit recursion in rules instead of doing it auto…
ced28167
Pass JAXLIB_* env variables to docker container
8ac29697
Expose pallas.mosaic.random.sample_block to pltpu interface
05614edc
Deprecate alpha argument to trsm LAPACK kernel.
2ce88c95
Port many uses of contextlib.contextdecorator to explicit context man…
256e37af
rocm-repo-management-api-2 requested a review 303 days ago
Merge branch 'rocm-main' into ci-upstream-sync-127_1
45e2060b
Install numa library
e82b4e22
Fix numa package
72ecacd8
Fix numactl-devel name
1217ba90
charleshofer approved these changes on 2025-02-26
charleshofer merged 877deed7 into rocm-main 302 days ago
charleshofer deleted the ci-upstream-sync-127_1 branch 302 days ago
Reviewers: charleshofer
Assignees: no one assigned
Labels: none yet
Milestone: no milestone