jax
CI: 06/19/25 upstream sync
#476
Merged

charleshofer merged 854 commits into rocm-main from ci-upstream-sync-219_1
rocm-repo-management-api-2
rocm-repo-management-api-2 requested a review 221 days ago
rocm-repo-management-api-2[bot] enabled auto-merge (rebase) 221 days ago
yashk2810 Fix pgle test breakage
6bbd6ca4
justinjfu [Pallas] Fix missing sub lowering rule for sparsecore.
3f700ca3
yashk2810 Only infer sharding from input in full_like (in eager mode) if the in…
b2550031
rdyro Make experimental pytree_serialization visible in OSS jax build
d253118f
Google-ML-Automation Update XLA dependency to use revision
6935d34a
bartchr808 #sdy Fallback to GSPMD in JAX export if the loaded module was lowered…
fb7023f2
bchetioui [Pallas/Mosaic GPU] Expose the new `TCGEN05_ROW` layout.
ea7023bd
bartchr808 #sdy Have JAX export compat tests also run on Shardy.
c46a9e35
dimitar-asenov [Mosaic GPU] Use the `mosaic_gpu.sliceSMEM` MLIR op when using WG sem…
345d19c9
ZixuanJiang Raise `NotImplementedError` instead of `ValueError` when using Shardy…
0ceca22a
danielsuo [jaxlib] Bind 'compile' to `xla::PyClient::Compile` rather than `xla:…
8b5b41db
danielsuo [jax::compiler] Bind `compiler.backend_compile` to `xla::PyClient::Co…
a44f7085
yashk2810 Reverts 6cd196a5db22b8db0ed4000e4cf67ad748bf52f3
ffe9161e
hawkinsp Fix a rare numerical flake in svd_test seen on TPU v6e.
28c25804
ZixuanJiang Add not-implemented sharding rule in `third_party/py/jax/_src/cudnn/f…
dcb926cb
yashk2810 Make `unreduced` argument in `PartitionSpec` a `set | frozenset` inst…
5562a35d
Rifur13 [Mosaic GPU] Add a test for TMA multicasts in pallas. This also effec…
64395618
cclauss Fix typos discovered by codespell
af7c6577
Google-ML-Automation Fix documentation for the CLI `up` command in the debugger.
b7313cc8
pschuh Link c-api raw buffer support into jaxlib.
bb1bedf6
sharadmv [Pallas Fuser] Add basic reshape push rule
35770432
emilyfertig [Rollback] Roll-forward with fix and test: prototype of cross-host de…
3432563d
dougalm Add a general system for keeping track of quasi-dynamic data (QDD).
735514ba
mattjj fix sharding-in-types + from_edtype
b6682d7e
Google-ML-Automation Update XLA dependency to use revision
b72310a1
bchetioui [Pallas/Mosaic GPU] Expose the new `TCGEN05_COL` layout.
09284d48
hawkinsp Don't repeatedly recompute a tuple of axis names for a membership test.
3f24bd5e
dimitar-asenov [Mosaic GPU] Move `should_have_transforms` to `inference_utils`.
79a20167
justinjfu [Pallas][Mosaic GPU] Use separate allocations for collective TMEM.
d25f9451
jakevdp jnp.array: avoid call to stack
fc894e1e
yashk2810 Cache get_vma because it's the same thing we do for `get_sharding` an…
4bb97a22
mattjj skip pytype on slow file
bcd52e83
YousefElbrolosy doc: clarified lack of gpu support for schur and sqrtm
2b969f0d
jakevdp lax_numpy: move array and asarray to their own submodule
576282f1
emilyfertig Make sure unsupported transfers between multi-process CPU arrays and …
381a0755
mattjj [cleanup] remove core.gensym, and Var.suffix
be47128b
hawkinsp Don't recompute source_info.current() in DynamicJaxprTracer.
79dd3393
justinjfu [Pallas][Mosaic GPU] Support column slicing on TMEM.
12354f31
bchetioui [Mosaic GPU] Fix `2xf32 -> 2xf8e4m3fn` conversion.
5ae4c725
justinjfu [Pallas][Mosaic GPU] Skip tcgen05 reduce test on WG semantics.
e4277634
justinjfu [Pallas][Mosaic GPU] Add support for load/broadcast using TCGEN05 ROW…
190f762e
jakevdp lax.top_k: raise error if indices will overflow
75698164
pschuh Bring back tree concat optimization for np.array(...)
1fc8ef25
Google-ML-Automation [Mosaic] Adds both direct (where hardware can) support for int8 Trans…
07faf131
bchetioui [Mosaic GPU] Add support for lowering `2xbf16 -> 2xf8e4m3fn` converts.
c5e87fff
emilyfertig Fix logic for checking supported cross-host device transfers, since t…
61d2bc65
Google-ML-Automation Update XLA dependency to use revision
b2480e7d
dimitar-asenov [Mosaic GPU] Ensure all ops that need transforms have them at the end…
e828a151
sharadmv Add fast-path for non-concrete Tracers in is_constant_dim to lower tr…
611b334c
Reverts 5c33588b30edbae51d5b63b0bd7cc8d9058d7ccb
1edc1501
dimitar-asenov [Mosaic GPU] Extract the type-related logic out of `reinterpret_smem_…
88332601
hawkinsp Port PartitionSpec to C++.
a2c63fac
hawkinsp Fix segfault if None is passed to PartitionSpec.__eq__.
e5faa078
justinjfu [Pallas][Mosaic GPU] Expose partitioned collective loads to copy_gmem…
7d1b479d
hawkinsp Optimize jaxpr equation pretty-printing.
fe379689
hyeontaek [JAX] Allow registering callbacks to be called when backends are cleared
65b43fc1
yashk2810 PRNGKeyArray doesn't have a format field so assign the layout to be N…
fe0324b4
mattjj Reverts c1bb095c5ce5b0286dc5052abf3b597b6f23cea5
5d1ba676
yashk2810 Add committed property to MutableArray
34bff904
sharadmv [Pallas TPU] Add custom_vjp_call lowering rule
6641b22d
hawkinsp Small speedups to pretty-printing.
965629f3
yashk2810 Set `in_sharding` to UNSPECIFIED if a mutableArray is uncommitted whe…
ff54073a
justinjfu [Pallas][Mosaic GPU] Add collective (CTA-pair) MMAs to blackwell matm…
bfddeea5
mattjj [mutable-arrays] make partial_eval_jaxpr forward input-residuals
d42fd533
Google-ML-Automation Update XLA dependency to use revision
cd5373bc
mattjj fix vestigial change that caused breakage
a1f13342
Google-ML-Automation Update XLA dependency to use revision
bb548c71
hawkinsp Port pretty-printer to C++.
56574021
Google-ML-Automation Fix typo in error message.
7cd006f4
Google-ML-Automation [Pallas][Easy] Terser printing of GridMapping unless debug is set.
2bf935a8
carlosgmartin Minor fix to doc for random.orthogonal.
37d24871
carlosgmartin Clarify argument order for lax.associative_scan when reverse=True.
21164dea
Move jax/_src/api.py and associated files to their own BUILD rule
d8911a27
bartchr808 Reverts b7833e94c1940ed475dae1f5e83e2a984cda5cea
8f256f2c
dfm Don't trigger debug_infs in ndtri unless an inf is returned.
c32b14d9
jakevdp jax.nn.standardize: improve documentation
0a0da47e
Move jax/_src/custom_batching.py to its own build rule
1d7c021b
Move jax/_src/earray.py to its own build rule
4cbb8418
Google-ML-Automation Fix spelling error in the name of the input variable.
05c7822c
Rifur13 [Mosaic GPU] Error when causal masking is used on cuda versions known…
fc51c38e
Move jax/_src/ffi.py to its own build rule
e5b6fc35
Move jax/_src/custom_partitioning.py to its own build rule
44328507
jakevdp [array-api] pin array-api-tests to 2025.05.23
b3f4104d
Move jax/_src/buffer_callback.py to its own build rule
147805ea
mattjj [mutable-arrays] make custom_api_test.py pass with JAX_MUTABLE_ARRAY_…
bdf171de
Move jax/_src/shard_alike.py to its own build rule
6de4f1c4
Google-ML-Automation [Pallas TPU][NFC] Use register to track buffer slots in pipeline loop
6327c899
Google-ML-Automation Remove forward_compat check for alpha as it is past the support date.
181a1f7e
toli-y Update JAX test to not rely on ToString and instead check the Device …
8d059f28
hawkinsp Don't revisit shared subjaxprs in jaxpr_util.pprof_equation_profile.
abbf66cb
Google-ML-Automation Automated Code Change
b80319f3
Google-ML-Automation Update XLA dependency to use revision
0095ad24
gnecula [jax2tf] Refine the disabling of jax2tf_test, for versions <= 2.19.1
8506f511
Move jax/_src/public_test_util.py to its own build rule
31f00539
bchetioui [Pallas/Mosaic GPU] Fix the abstract eval rule for `load_p` in the pr…
e9bfb8a1
benquike fix type annotation for _IndexUpdateRef.get in _src/basearray.pyi
f1499d22
hawkinsp Pass source_info to custom_staging_rules and into jaxpr inlining.
354c8a07
mattjj [mutable-arrays] upgrade scan to work with partial_eval_jaxpr_fwd
0613a11f
jenriver add reference to pr-checklist
99c0170c
liepieshov [jax2tf] fix jax2tf sharding tests for shardy
0d8817cd
Rifur13 [Mosaic GPU] Fix test after a previous PR changed the config params.
35e08582
jakeharmon8 Skip NumPy's `isClose` test for NumPy 2.3.0
f812cc16
mattjj fix for a downstream breakage from #29353
98aca346
subhankarshah * Add support for output and input memory space colors in tpu custom …
c4bca143
yashk2810 [JAX] Add `vma` to `ShapeDtypeStruct` constructor arguments.
1fb5bc7f
IvyZX Add is_leaf_with_path predicate.
68e209d6
Google-ML-Automation Pallas documentation fixes.
6b4914a6
mattjj Rollback of #29353 due to downstream failures
a3e9bab6
yashk2810 Initial commit to make unreduced + AD work.
2e6f2dc5
hyeontaek [JAX] Move the fallback of `colocated_cpu_devices` logic from the col…
6a2dc9b0
yashk2810 Add basic mutable array tests with AOT
955a0520
hawkinsp Save a jaxpr equation in pl.cdiv if the rhs is an int.
213b4d59
Google-ML-Automation Update XLA dependency to use revision
e11a820d
stjerngren Ensure that all attributes are restored after pickling in `NamedShard…
b56b5350
dimitar-asenov [Mosaic GPU] Use _slice_smem also for barriers.
ad7e14af
superbobry [pallas:mosaic] A few more primitives now have lowerings for all kern…
d1768fdb
dimitar-asenov [Mosaic GPU] Remove unneeded code.
5b13235d
hawkinsp Propagate source_info in more places:
1d5649e7
stjerngren Ensure that memory_kind is restored after pickling in SingleDeviceSha…
8c0e5007
hawkinsp Don't recompute np.iinfo in _scalar_type_to_dtype.
43b527f4
mooskagh Set explicit dot precision in the sparse solver test.
9bd455ea
hawkinsp Delete instantiate_const_abstracted.
830784de
yashk2810 add doc comment to vma in ShapedArray
f9ec83b2
hawkinsp Do not call update_weak_type on the result of get_aval().
0c345d14
Move jax/_src/export to its own build rule
3b9bc57a
olupton //tests:scaled_matmul_stablehlo_test: fix for xla#27096
f854d796
jakeharmon8 Move materialization of NDIndexer out of draw()
4c6a244a
hyeontaek [JAX] Extend `colocated_cpu_devices` to accept `Mesh` besides devices
84d65670
Google-ML-Automation Expose local/global `ExchangeTopologies` timeouts for PJRT CPU client.
7a23dfb3
superbobry Migrated to mypy 1.16.0
1036afdb
rdyro extend pallas paged_attention with kv scales
c0f675a2
Google-ML-Automation Add alternative location of `CUDA_ROOT` for Bazel build/tests with he…
ef8da84d
Matt-Hurd [XProf] Change tensorboard-plugin-profile to new xprof package
ea10dd29
hawkinsp Add a pytype disable around zstandard.
9fc94c3a
yashk2810 Add execution to unreduced tests now that it works end-to-end
2f9f2a5a
kanglant Add nightly linux jax wheel tests for python 3.14.0b1
737c63a5
rdyro Fix GPU quantized paged attention tests for < sm89
f3cfec14
gnecula Improve batching for lax.platform_dependent
b714df3d
Google-ML-Automation Update XLA dependency to use revision
275928b2
hawkinsp Move NamedSharding.__eq__ and NamedSharding.__hash__ into C++.
f71d63fe
bchetioui [Mosaic GPU] Add conversion logic for `i4 -> f8e4m3fn`.
e3c7e83a
hawkinsp Temporarily disable AVX512 in linalg_test_cpu.
6e8e9118
Google-ML-Automation Add hermetic `nvshmem` dependencies to JAX targets.
6397393c
DanisNone add missing dtypes to jax.numpy.__init__.pyi
18011fe0
sharadmv [Pallas TPU] Support memory space constraints on pallas_call inputs.
2f594b80
Google-ML-Automation [JAX] Update the example to use jax.numpy rather than numpy.
b3c5c708
hawkinsp Reland the C++ safe_zip implementation.
bc2db638
yashk2810 Add `all_gather_invariant` to lax.
5082f634
Google-ML-Automation [Pallas TPU] Small fix to memory space constraints on pallas_call inp…
ccc9c4f6
dimitar-asenov [Mosaic GPU] Enable transpose tests in mosaic_gpu.
b98cfaeb
jakeharmon8 Add colorama back into test-requirements
aba2b654
jakevdp [doc] fix some inaccuracies in jnp.bincount docs
68b7ad16
hawkinsp Use a frozenset for unconstrained_dims in sharding_constraint_p.
1d6e9519
hawkinsp [jaxlib] Change Traceback to be a raw CPython class rather than a nan…
4191dc89
hawkinsp Make the params of more jaxpr primitives hashable.
d220227b
Remove unused internal optimization_barrier alias
e10d1fec
dimitar-asenov [Mosaic GPU] Convert all memrefs with transforms to unrealized casts …
d8d0efa3
jburnim Fix return type annotation for tree_util.tree_broadcast.
ecedb8d3
justinjfu [Pallas] Add no_pipelining debugging option to emit_pipeline.
e5fc134a
yashk2810 Replace `with_partitions` and `with_unreduced` with `.update` on Part…
d2ee60a0
hyeontaek [JAX] Fix the test names in colocated_python_test.py to following the…
d9446e2c
Google-ML-Automation Update XLA dependency to use revision
f43f9ee3
superbobry Removed unused `PyTreeDef::MakeFromNodeDataAndChildren` and its Pytho…
fd64ce48
gnecula [export] Add back-compat test for tridiagonal solve on GPU
ef603037
jakevdp [doc] add missing axis_types documentation
005acd2f
hawkinsp Fix a missing bounds check in traceback code.
2fa8e357
tlongeri [Mosaic:TPU][NFC] Delete unused variable
8c14f6cc
hyeontaek [JAX] Relax the return type of `colocated_python` decorator
bb0291e4
zacmustin Add custom-call ops to roofline.
d71e5fcd
sannidhyachauhan Removing Tensorflow references from the document.
bff97363
sannidhyachauhan Add test for programmatic tracing with options.
412d80fb
Google-ML-Automation Update XLA dependency to use revision
4c3f40f0
ZixuanJiang Pass through the `use_shardy_partitioner` with `jax.config.jax_use_sh…
a0aa266a
WindQAQ [Mosaic] Use BF16 ops for math::PowF on TPUv6+.
52ae9e22
jburnim Update Pallas debugging doc with TPU interpret mode + dynamic race de…
b2d89da8
Google-ML-Automation Prefer binaries in NVIDIA `nvcc` wheel over system CUDA installation …
2d2c58b0
cky9301 Add an API to overwrite the current execution_stream_id and respect i…
a3b8c508
Google-ML-Automation [Pallas TPU] Add flag to enable using registers to keep track of slot…
6a7c4d3a
superbobry [pallas] `AbstractMemoryRef` now implements all functional update met…
dbb7794d
vfdev-5 Removed fixed suppressions
cfbfd287
jakevdp jax.experimental.enable_x64: add warning to docstring
54e34127
DanisNone add jax.nn module type hints (__init__.pyi)
d17dc18f
justinjfu [Pallas][Mosaic GPU] Enable collective MMA from TMEM.
3b5c51f3
Google-ML-Automation Update XLA dependency to use revision
6803fbfd
rosiezou add psend and precv to jax/lax/parallel
1f46aa4d
Rollback https://github.com/jax-ml/jax/pull/29410 due to downstream p…
95f5eacf
zacmustin Add `cum{logsumexp, min, max, prod, sum}` to JAX roofline.
8adb773c
hyeontaek [JAX] Remove sleeping from colocated Python execution tests
f3e5c89f
yashk2810 Prepare for JAX release 0.6.2
2a1f66b9
yashk2810 Postrelease (0.6.2) changes
98e91122
zacmustin Add `gather` to `roofline`.
4fb6f1bc
yashk2810 jaxlib_extension_version == 355 after 0.6.2 release. So remove the co…
361cf02d
tlongeri [Mosaic:TPU] Byte-granularity dynamic gathers
66795a13
superbobry [mosaic] Added a `k` prefix to `TPU_MemorySpace` members
6348ed9f
apaszke [Mosaic GPU] Rework the CUDA_ROOT detection once again
51fea783
apaszke [Mosaic GPU] Add support for s8 matmuls on Blackwell
93d8a323
bchetioui [Mosaic GPU] Implement canonicalization for `TiledLayout`s.
b0a7e8af
hawkinsp Drop Python 3.10 support.
690bf78a
yashk2810 Remove `_allow_deprecated_jit_signature` now that 0.6.2 is out and ne…
2746b769
hawkinsp Add a cache around abstract_eval rules.
56cfa598
jakevdp Finalize a number of deprecations for JAX v0.7.0
8ca8453b
Google-ML-Automation [Mosaic TPU] Make the backward-compatibility libtpu condition stricter
1a048aa0
justinjfu [Mosaic GPU] Fix minor error in matmul test.
cfdadde2
hawkinsp Remove some dangling references from the docs.
59a26100
yashk2810 Remove PositionalSharding from JAX now that 0.6.2 is out and next rel…
2420937e
kanglant Skip pytest for Python 3.14 during the JAX release process
296dd2a7
jburnim Fix rare error with Literal in DynamicJaxprTracer.full_lower.
b20b32c4
Move jax._src.lax to its own BUILD rule
94c1e891
hawkinsp Run pyupgrade --py311-plus.
b2f117e9
dubstack Fix bugs in the double_buffered_pipeline example
8e2c2abb
jakevdp [doc] fix build error
23d99fcb
kanglant Update jax requirements lock files after 0.6.2 release
f9bc8652
subhankarshah [JAX] Skip failing tpu tests until June 30th.
57858414
jakevdp Add wrap_negative_indices parameter to jnp.ndarray.at[]
24b0ae00
justinjfu [Pallas][Mosaic GPU] Add GPU pipelining docs
00e51508
Cjkkkk add cudnn sdpa mla support
60d85826
ZixuanJiang Pass shardy option through jax config.
b75e2227
charleshofer force pushed from 1e4a0f76 to b75e2227 221 days ago
charleshofer Don't build Python 3.10
273fa845
charleshofer approved these changes on 2025-06-20
disabled auto-merge 220 days ago
Rebase failed
charleshofer merged 273fa845 into rocm-main 220 days ago
charleshofer deleted the ci-upstream-sync-219_1 branch 220 days ago
