jax
CI: 06/19/25 upstream sync
#476
Merged
charleshofer merged 854 commits into rocm-main from ci-upstream-sync-219_1
rocm-repo-management-api-2 requested a review 221 days ago
rocm-repo-management-api-2[bot] enabled auto-merge (rebase) 221 days ago
Fix pgle test breakage
6bbd6ca4
[Pallas] Fix missing sub lowering rule for sparsecore.
3f700ca3
Only infer sharding from input in full_like (in eager mode) if the in…
b2550031
Make experimental pytree_serialization visible in OSS jax build
d253118f
Update XLA dependency to use revision
6935d34a
#sdy Fallback to GSPMD in JAX export if the loaded module was lowered…
fb7023f2
[Pallas/Mosaic GPU] Expose the new `TCGEN05_ROW` layout.
ea7023bd
#sdy Have JAX export compat tests also run on Shardy.
c46a9e35
[Mosaic GPU] Use the `mosaic_gpu.sliceSMEM` MLIR op when using WG sem…
345d19c9
Raise `NotImplementedError` instead of `ValueError` when using Shardy…
0ceca22a
[jaxlib] Bind 'compile' to `xla::PyClient::Compile` rather than `xla:…
8b5b41db
[jax::compiler] Bind `compiler.backend_compile` to `xla::PyClient::Co…
a44f7085
Reverts 6cd196a5db22b8db0ed4000e4cf67ad748bf52f3
ffe9161e
Fix a rare numerical flake in svd_test seen on TPU v6e.
28c25804
Add not-implemented sharding rule in `third_party/py/jax/_src/cudnn/f…
dcb926cb
Make `unreduced` argument in `PartitionSpec` a `set | frozenset` inst…
5562a35d
[Mosaic GPU] Add a test for TMA multicasts in pallas. This also effec…
64395618
Fix typos discovered by codespell
af7c6577
Fix documentation for the CLI `up` command in the debugger.
b7313cc8
Link c-api raw buffer support into jaxlib.
bb1bedf6
[Pallas Fuser] Add basic reshape push rule
35770432
[Rollback] Roll-forward with fix and test: prototype of cross-host de…
3432563d
Add a general system for keeping track of quasi-dynamic data (QDD).
735514ba
fix sharding-in-types + from_edtype
b6682d7e
Update XLA dependency to use revision
b72310a1
[Pallas/Mosaic GPU] Expose the new `TCGEN05_COL` layout.
09284d48
Don't repeatedly recompute a tuple of axis names for a membership test.
3f24bd5e
[Mosaic GPU] Move `should_have_transforms` to `inference_utils`.
79a20167
[Pallas][Mosaic GPU] Use separate allocations for collective TMEM.
d25f9451
jnp.array: avoid call to stack
fc894e1e
Cache get_vma because it's the same thing we do for `get_sharding` an…
4bb97a22
skip pytype on slow file
bcd52e83
doc: clarified lack of gpu support for schur and sqrtm
2b969f0d
lax_numpy: move array and asarray to their own submodule
576282f1
Make sure unsupported transfers between multi-process CPU arrays and …
381a0755
[cleanup] remove core.gensym, and Var.suffix
be47128b
Don't recompute source_info.current() in DynamicJaxprTracer.
79dd3393
[Pallas][Mosaic GPU] Support column slicing on TMEM.
12354f31
[Mosaic GPU] Fix `2xf32 -> 2xf8e4m3fn` conversion.
5ae4c725
[Pallas][Mosaic GPU] Skip tcgen05 reduce test on WG semantics.
e4277634
[Pallas][Mosaic GPU] Add support for load/broadcast using TCGEN05 ROW…
190f762e
lax.top_k: raise error if indices will overflow
75698164
Bring back tree concat optimization for np.array(...)
1fc8ef25
[Mosaic] Adds both direct (where hardware can) support for int8 Trans…
07faf131
[Mosaic GPU] Add support for lowering `2xbf16 -> 2xf8e4m3fn` converts.
c5e87fff
Fix logic for checking supported cross-host device transfers, since t…
61d2bc65
Update XLA dependency to use revision
b2480e7d
[Mosaic GPU] Ensure all ops that need transforms have them at the end…
e828a151
Add fast-path for non-concrete Tracers in is_constant_dim to lower tr…
611b334c
Reverts 5c33588b30edbae51d5b63b0bd7cc8d9058d7ccb
1edc1501
[Mosaic GPU] Extract the type-related logic out of `reinterpret_smem_…
88332601
Port PartitionSpec to C++.
a2c63fac
Fix segfault if None is passed to PartitionSpec.__eq__.
e5faa078
[Pallas][Mosaic GPU] Expose partitioned collective loads to copy_gmem…
7d1b479d
Optimize jaxpr equation pretty-printing.
fe379689
[JAX] Allow registering callbacks to be called when backends are cleared
65b43fc1
PRNGKeyArray doesn't have a format field so assign the layout to be N…
fe0324b4
Reverts c1bb095c5ce5b0286dc5052abf3b597b6f23cea5
5d1ba676
Add committed property to MutableArray
34bff904
[Pallas TPU] Add custom_vjp_call lowering rule
6641b22d
Small speedups to pretty-printing.
965629f3
Set `in_sharding` to UNSPECIFIED if a mutableArray is uncommitted whe…
ff54073a
[Pallas][Mosaic GPU] Add collective (CTA-pair) MMAs to blackwell matm…
bfddeea5
[mutable-arrays] make partial_eval_jaxpr forward input-residuals
d42fd533
Update XLA dependency to use revision
cd5373bc
fix vestigial change that caused breakage
a1f13342
Update XLA dependency to use revision
bb548c71
Port pretty-printer to C++.
56574021
Fix typo in error message.
7cd006f4
[Pallas][Easy] Terser printing of GridMapping unless debug is set.
2bf935a8
Minor fix to doc for random.orthogonal.
37d24871
Clarify argument order for lax.associative_scan when reverse=True.
21164dea
Move jax/_src/api.py and associated files to their own BUILD rule
d8911a27
Reverts b7833e94c1940ed475dae1f5e83e2a984cda5cea
8f256f2c
Don't trigger debug_infs in ndtri unless an inf is returned.
c32b14d9
jax.nn.standardize: improve documentation
0a0da47e
Move jax/_src/custom_batching.py to its own build rule
1d7c021b
Move jax/_src/earray.py to its own build rule
4cbb8418
Fix spelling error in the name of the input variable.
05c7822c
[Mosaic GPU] Error when causal masking is used on cuda versions known…
fc51c38e
Move jax/_src/ffi.py to its own build rule
e5b6fc35
Move jax/_src/custom_partitioning.py to its own build rule
44328507
[array-api] pin array-api-tests to 2025.05.23
b3f4104d
Move jax/_src/buffer_callback.py to its own build rule
147805ea
[mutable-arrays] make custom_api_test.py pass with JAX_MUTABLE_ARRAY_…
bdf171de
Move jax/_src/shard_alike.py to its own build rule
6de4f1c4
[Pallas TPU][NFC] Use register to track buffer slots in pipeline loop
6327c899
Remove forward_compat check for alpha as it is past the support date.
181a1f7e
Update JAX test to not rely on ToString and instead check the Device …
8d059f28
Don't revisit shared subjaxprs in jaxpr_util.pprof_equation_profile.
abbf66cb
Automated Code Change
b80319f3
Update XLA dependency to use revision
0095ad24
[jax2tf] Refine the disabling of jax2tf_test, for versions <= 2.19.1
8506f511
Move jax/_src/public_test_util.py to its own build rule
31f00539
[Pallas/Mosaic GPU] Fix the abstract eval rule for `load_p` in the pr…
e9bfb8a1
fix type annotation for _IndexUpdateRef.get in _src/basearray.pyi
f1499d22
Pass source_info to custom_staging_rules and into jaxpr inlining.
354c8a07
[mutable-arrays] upgrade scan to work with partial_eval_jaxpr_fwd
0613a11f
add reference to pr-checklist
99c0170c
[jax2tf] fix jax2tf sharding tests for shardy
0d8817cd
[Mosaic GPU] Fix test after a previous PR changed the config params.
35e08582
Skip NumPy's `isClose` test for NumPy 2.3.0
f812cc16
fix for a downstream breakage from #29353
98aca346
* Add support for output and input memory space colors in tpu custom …
c4bca143
[JAX] Add `vma` to `ShapeDtypeStruct` constructor arguments.
1fb5bc7f
Add is_leaf_with_path predicate.
68e209d6
Pallas documentation fixes.
6b4914a6
Rollback of #29353 due to downstream failures
a3e9bab6
Initial commit to make unreduced + AD work.
2e6f2dc5
[JAX] Move the fallback of `colocated_cpu_devices` logic from the col…
6a2dc9b0
Add basic mutable array tests with AOT
955a0520
Save a jaxpr equation in pl.cdiv if the rhs is an int.
213b4d59
Update XLA dependency to use revision
e11a820d
Ensure that all attributes are restored after pickling in `NamedShard…
b56b5350
[Mosaic GPU] Use _slice_smem also for barriers.
ad7e14af
[pallas:mosaic] A few more primitives now have lowerings for all kern…
d1768fdb
[Mosaic GPU] Remove unneeded code.
5b13235d
Propagate source_info in more places:
1d5649e7
Ensure that memory_kind is restored after pickling in SingleDeviceSha…
8c0e5007
Don't recompute np.iinfo in _scalar_type_to_dtype.
43b527f4
Set explicit dot precision in the sparse solver test.
9bd455ea
Delete instantiate_const_abstracted.
830784de
add doc comment to vma in ShapedArray
f9ec83b2
Do not call update_weak_type on the result of get_aval().
0c345d14
Move jax/_src/export to its own build rule
3b9bc57a
//tests:scaled_matmul_stablehlo_test: fix for xla#27096
f854d796
Move materialization of NDIndexer out of draw()
4c6a244a
[JAX] Extend `colocated_cpu_devices` to accept `Mesh` besides devices
84d65670
Expose local/global `ExchangeTopologies` timeouts for PJRT CPU client.
7a23dfb3
Migrated to mypy 1.16.0
1036afdb
extend pallas paged_attention with kv scales
c0f675a2
Add alternative location of `CUDA_ROOT` for Bazel build/tests with he…
ef8da84d
[XProf] Change tensorboard-plugin-profile to new xprof package
ea10dd29
Add a pytype disable around zstandard.
9fc94c3a
Add execution to unreduced tests now that it works end-to-end
2f9f2a5a
Add nightly linux jax wheel tests for python 3.14.0b1
737c63a5
Fix GPU quantized paged attention tests for < sm89
f3cfec14
Improve batching for lax.platform_dependent
b714df3d
Update XLA dependency to use revision
275928b2
Move NamedSharding.__eq__ and NamedSharding.__hash__ into C++.
f71d63fe
[Mosaic GPU] Add conversion logic for `i4 -> f8e4m3fn`.
e3c7e83a
Temporarily disable AVX512 in linalg_test_cpu.
6e8e9118
Add hermetic `nvshmem` dependencies to JAX targets.
6397393c
add missing dtypes to jax.numpy.__init__.pyi
18011fe0
[Pallas TPU] Support memory space constraints on pallas_call inputs.
2f594b80
[JAX] Update the example to use jax.numpy rather than numpy.
b3c5c708
Reland the C++ safe_zip implementation.
bc2db638
Add `all_gather_invariant` to lax.
5082f634
[Pallas TPU] Small fix to memory space constraints on pallas_call inp…
ccc9c4f6
[Mosaic GPU] Enable transpose tests in mosaic_gpu.
b98cfaeb
Add colorama back into test-requirements
aba2b654
[doc] fix some inaccuracies in jnp.bincount docs
68b7ad16
Use a frozenset for unconstrained_dims in sharding_constraint_p.
1d6e9519
[jaxlib] Change Traceback to be a raw CPython class rather than a nan…
4191dc89
Make the params of more jaxpr primitives hashable.
d220227b
Remove unused internal optimization_barrier alias
e10d1fec
[Mosaic GPU] Convert all memrefs with transforms to unrealized casts …
d8d0efa3
Fix return type annotation for tree_util.tree_broadcast.
ecedb8d3
[Pallas] Add no_pipelining debugging option to emit_pipeline.
e5fc134a
Replace `with_partitions` and `with_unreduced` with `.update` on Part…
d2ee60a0
[JAX] Fix the test names in colocated_python_test.py to following the…
d9446e2c
Update XLA dependency to use revision
f43f9ee3
Removed unused `PyTreeDef::MakeFromNodeDataAndChildren` and its Pytho…
fd64ce48
[export] Add back-compat test for tridiagonal solve on GPU
ef603037
[doc] add missing axis_types documentation
005acd2f
Fix a missing bounds check in traceback code.
2fa8e357
[Mosaic:TPU][NFC] Delete unused variable
8c14f6cc
[JAX] Relax the return type of `colocated_python` decorator
bb0291e4
Add custom-call ops to roofline.
d71e5fcd
Removing Tensorflow references from the document.
bff97363
Add test for programmatic tracing with options.
412d80fb
Update XLA dependency to use revision
4c3f40f0
Pass through the `use_shardy_partitioner` with `jax.config.jax_use_sh…
a0aa266a
[Mosaic] Use BF16 ops for math::PowF on TPUv6+.
52ae9e22
Update Pallas debugging doc with TPU interpret mode + dynamic race de…
b2d89da8
Prefer binaries in NVIDIA `nvcc` wheel over system CUDA installation …
2d2c58b0
Add an API to overwrite the current execution_stream_id and respect i…
a3b8c508
[Pallas TPU] Add flag to enable using registers to keep track of slot…
6a7c4d3a
[pallas] `AbstractMemoryRef` now implements all functional update met…
dbb7794d
Removed fixed suppressions
cfbfd287
jax.experimental.enable_x64: add warning to docstring
54e34127
add jax.nn module type hints (__init__.pyi)
d17dc18f
[Pallas][Mosaic GPU] Enable collective MMA from TMEM.
3b5c51f3
Update XLA dependency to use revision
6803fbfd
add psend and precv to jax/lax/parallel
1f46aa4d
Rollback https://github.com/jax-ml/jax/pull/29410 due to downstream p…
95f5eacf
Add `cum{logsumexp, min, max, prod, sum}` to JAX roofline.
8adb773c
[JAX] Remove sleeping from colocated Python execution tests
f3e5c89f
Prepare for JAX release 0.6.2
2a1f66b9
Postrelease (0.6.2) changes
98e91122
Add `gather` to `roofline`.
4fb6f1bc
jaxlib_extension_version == 355 after 0.6.2 release. So remove the co…
361cf02d
[Mosaic:TPU] Byte-granularity dynamic gathers
66795a13
[mosaic] Added a `k` prefix to `TPU_MemorySpace` members
6348ed9f
[Mosaic GPU] Rework the CUDA_ROOT detection once again
51fea783
[Mosaic GPU] Add support for s8 matmuls on Blackwell
93d8a323
[Mosaic GPU] Implement canonicalization for `TiledLayout`s.
b0a7e8af
Drop Python 3.10 support.
690bf78a
Remove `_allow_deprecated_jit_signature` now that 0.6.2 is out and ne…
2746b769
Add a cache around abstract_eval rules.
56cfa598
Finalize a number of deprecations for JAX v0.7.0
8ca8453b
[Mosaic TPU] Make the backward-compatibility libtpu condition stricter
1a048aa0
[Mosaic GPU] Fix minor error in matmul test.
cfdadde2
Remove some dangling references from the docs.
59a26100
Remove PositionalSharding from JAX now that 0.6.2 is out and next rel…
2420937e
Skip pytest for Python 3.14 during the JAX release process
296dd2a7
Fix rare error with Literal in DynamicJaxprTracer.full_lower.
b20b32c4
Move jax._src.lax to its own BUILD rule
94c1e891
Run pyupgrade --py311-plus.
b2f117e9
Fix bugs in the double_buffered_pipeline example
8e2c2abb
[doc] fix build error
23d99fcb
Update jax requirements lock files after 0.6.2 release
f9bc8652
[JAX] Skip failing tpu tests until June 30th.
57858414
Add wrap_negative_indices parameter to jnp.ndarray.at[]
24b0ae00
[Pallas][Mosaic GPU] Add GPU pipelining docs
00e51508
add cudnn sdpa mla support
60d85826
Pass shardy option through jax config.
b75e2227
charleshofer force pushed from 1e4a0f76 to b75e2227 221 days ago
Don't build Python 3.10
273fa845
charleshofer approved these changes on 2025-06-20
disabled auto-merge 220 days ago
Rebase failed
charleshofer merged 273fa845 into rocm-main 220 days ago
charleshofer deleted the ci-upstream-sync-219_1 branch 220 days ago
Reviewers: charleshofer
Assignees: No one assigned
Labels: None yet
Milestone: No milestone