jax
CI: 06/20/25 upstream sync
#477
Open

CI: 06/20/25 upstream sync #477

rocm-repo-management-api-2 wants to merge 2162 commits into main from ci-upstream-sync-220_1
rocm-repo-management-api-2
sharadmv Add fast-path for non-concrete Tracers in is_constant_dim to lower tr…
795bbfc6
Reverts 5c33588b30edbae51d5b63b0bd7cc8d9058d7ccb
c1bb095c
dimitar-asenov [Mosaic GPU] Extract the type-related logic out of `reinterpret_smem_…
64ba9bc4
hawkinsp Port PartitionSpec to C++.
8478fe6e
hawkinsp Fix segfault if None is passed to PartitionSpec.__eq__.
e4c8da1e
justinjfu [Pallas][Mosaic GPU] Expose partitioned collective loads to copy_gmem…
09111c8e
hawkinsp Optimize jaxpr equation pretty-printing.
7dd721a8
hyeontaek [JAX] Allow registering callbacks to be called when backends are cleared
58a19372
yashk2810 PRNGKeyArray doesn't have a format field so assign the layout to be N…
ec61161d
mattjj Reverts c1bb095c5ce5b0286dc5052abf3b597b6f23cea5
c9289ae1
yashk2810 Add committed property to MutableArray
567d61e2
sharadmv [Pallas TPU] Add custom_vjp_call lowering rule
3248d55a
hawkinsp Small speedups to pretty-printing.
051386f0
yashk2810 Set `in_sharding` to UNSPECIFIED if a mutableArray is uncommitted whe…
0d1edcce
justinjfu [Pallas][Mosaic GPU] Add collective (CTA-pair) MMAs to blackwell matm…
6ae26141
mattjj [mutable-arrays] make partial_eval_jaxpr forward input-residuals
3699aa74
Google-ML-Automation Merge pull request #29311 from mattjj:mutable-array-custom-vjp
0e9c0e4e
carlosgmartin Clarify argument order for lax.associative_scan when reverse=True.
58faffd0
Google-ML-Automation Update XLA dependency to use revision
1cb18ec2
mattjj fix vestigial change that caused breakage
cb2b2179
Google-ML-Automation Update XLA dependency to use revision
0d1b1ef7
hawkinsp Port pretty-printer to C++.
c2a56909
Google-ML-Automation Fix typo in error message.
d69086db
carlosgmartin Minor fix to doc for random.orthogonal.
880dd131
Google-ML-Automation [Pallas][Easy] Terser printing of GridMapping unless debug is set.
d8317b53
dfm Don't trigger debug_infs in ndtri unless an inf is returned.
2281455a
gnecula [jax2tf] Refine the disabling of jax2tf_test, for versions <= 2.19.1
a6d95ee3
Google-ML-Automation Merge pull request #29321 from carlosgmartin:fix_random_orthogonal_doc
a572a997
Google-ML-Automation Merge pull request #29313 from carlosgmartin:document_associative_sca…
74e362eb
Move jax/_src/api.py and associated files to their own BUILD rule
6d729fe7
bartchr808 Reverts b7833e94c1940ed475dae1f5e83e2a984cda5cea
92d5fe88
jakevdp jax.nn.standardize: improve documentation
88de1e61
Google-ML-Automation Merge pull request #29330 from dfm:ndtri-infs
52704b77
Google-ML-Automation Merge pull request #29332 from jakevdp:standardize-doc
3157265e
Move jax/_src/custom_batching.py to its own build rule
42ea2ac3
Move jax/_src/earray.py to its own build rule
89e0c7ec
jakevdp [array-api] pin array-api-tests to 2025.05.23
9651e608
Google-ML-Automation Fix spelling error in the name of the input variable.
7c35300e
Rifur13 [Mosaic GPU] Error when causal masking is used on cuda versions known…
39de7159
Move jax/_src/ffi.py to its own build rule
9a591624
Move jax/_src/custom_partitioning.py to its own build rule
bfd0744f
Google-ML-Automation Merge pull request #29344 from jakevdp:array-api-test-pin
a7c67e5a
Move jax/_src/buffer_callback.py to its own build rule
df50cd7e
mattjj [mutable-arrays] make custom_api_test.py pass with JAX_MUTABLE_ARRAY_…
31e99987
Google-ML-Automation Merge pull request #29352 from mattjj:mutable-array-custom-vjp-fixes
5f2ce853
Move jax/_src/shard_alike.py to its own build rule
a58b27c6
Google-ML-Automation [Pallas TPU][NFC] Use register to track buffer slots in pipeline loop
14aaa45f
jenriver add reference to pr-checklist
34cee968
Google-ML-Automation Remove forward_compat check for alpha as it is past the support date.
b22be866
toli-y Update JAX test to not rely on ToString and instead check the Device …
28c31b8c
jburnim [pallas] Fix shard_map + Megacore in TPU interpret mode.
65c14a2c
hawkinsp Don't revisit shared subjaxprs in jaxpr_util.pprof_equation_profile.
c126ee3d
Google-ML-Automation Automated Code Change
cc971e35
Google-ML-Automation Update XLA dependency to use revision
b59a97f0
Google-ML-Automation Merge pull request #28916 from gnecula:tf_version
297ee7c9
gnecula Improve batching for lax.platform_dependent
31ef2cf2
Move jax/_src/public_test_util.py to its own build rule
be23dcf2
bchetioui [Pallas/Mosaic GPU] Fix the abstract eval rule for `load_p` in the pr…
f053dfed
Google-ML-Automation Merge pull request #29257 from benquike:main
fdf7a2cb
hawkinsp Pass source_info to custom_staging_rules and into jaxpr inlining.
4846ed2f
mattjj [mutable-arrays] upgrade scan to work with partial_eval_jaxpr_fwd
5d64c395
Google-ML-Automation Merge pull request #29353 from mattjj:mutable-array-custom-vjp-scan2
56f3293c
Google-ML-Automation Merge pull request #29358 from jenriver:documentation
14a9b590
liepieshov [jax2tf] fix jax2tf sharding tests for shardy
b999fab1
Rifur13 [Mosaic GPU] Fix test after a previous PR changed the config params.
9e0472c1
jakeharmon8 Skip NumPy's `isClose` test for NumPy 2.3.0
c0541354
mattjj fix for a downstream breakage from #29353
03b01526
subhankarshah * Add support for output and input memory space colors in tpu custom …
89fca529
yashk2810 [JAX] Add `vma` to `ShapeDtypeStruct` constructor arguments.
bdd635a6
IvyZX Add is_leaf_with_path predicate.
160e59f1
Google-ML-Automation Pallas documentation fixes.
423aafe7
mattjj Rollback of #29353 due to downstream failures
ed03f383
yashk2810 Initial commit to make unreduced + AD work.
22ceb687
hyeontaek [JAX] Move the fallback of `colocated_cpu_devices` logic from the col…
d9e0244c
yashk2810 Add basic mutable array tests with AOT
fd6d90a7
hawkinsp Save a jaxpr equation in pl.cdiv if the rhs is an int.
f211c6b9
Google-ML-Automation Update XLA dependency to use revision
cd4a0c6e
stjerngren Ensure that all attributes are restored after pickling in `NamedShard…
02cdc7ba
dimitar-asenov [Mosaic GPU] Use _slice_smem also for barriers.
9da048ec
superbobry [pallas:mosaic] A few more primitives now have lowerings for all kern…
de82f9f4
dimitar-asenov [Mosaic GPU] Remove unneeded code.
827b855b
hawkinsp Propagate source_info in more places:
0edfb44f
stjerngren Ensure that memory_kind is restored after pickling in SingleDeviceSha…
1cd49920
hawkinsp Don't recompute np.iinfo in _scalar_type_to_dtype.
cfebc498
mooskagh Set explicit dot precision in the sparse solver test.
5f0e7e47
superbobry Migrated to mypy 1.16.0
225f0c62
hawkinsp Delete instantiate_const_abstracted.
f7f2ce5c
Google-ML-Automation Merge pull request #29350 from jburnim:jburnim_interpret_shard_map_pl…
f519ad7c
yashk2810 add doc comment to vma in ShapedArray
6c9bcfc6
hawkinsp Do not call update_weak_type on the result of get_aval().
b87ea1c5
Move jax/_src/export to its own build rule
772cde81
Google-ML-Automation Merge pull request #29294 from olupton:fix-scaled-matmul-stablehlo-test
703cf91b
jakeharmon8 Move materialization of NDIndexer out of draw()
2618d9b1
hyeontaek [JAX] Extend `colocated_cpu_devices` to accept `Mesh` besides devices
5c2a3201
rdyro extend pallas paged_attention with kv scales
9f8be25f
Google-ML-Automation Expose local/global `ExchangeTopologies` timeouts for PJRT CPU client.
e61ee7e2
Google-ML-Automation Merge pull request #29405 from superbobry:union-fix
5b49b289
Google-ML-Automation Merge pull request #29354 from rdyro:paged_attention_with_scales
b40d79c2
Matt-Hurd [XProf] Change tensorboard-plugin-profile to new xprof package
97ab9e0b
Google-ML-Automation Add alternative location of `CUDA_ROOT` for Bazel build/tests with he…
004b5480
Google-ML-Automation Merge pull request #29129 from Matt-Hurd:rename_to_xprof
d00b1cad
hawkinsp Add a pytype disable around zstandard.
3ef4db4d
yashk2810 Add execution to unreduced tests now that it works end-to-end
cc976cb0
kanglant Add nightly linux jax wheel tests for python 3.14.0b1
45e61d8c
rdyro Fix GPU quantized paged attention tests for < sm89
e7d252dc
Google-ML-Automation Merge pull request #29423 from rdyro:gpu_paged_fix
2c467d6a
Google-ML-Automation Merge pull request #29362 from gnecula:platform_index_linearize
1886a9e1
Google-ML-Automation Update XLA dependency to use revision
4ad4d4a8
DanisNone add jax.nn module type hints (__init__.pyi)
87ce7d1b
hawkinsp Move NamedSharding.__eq__ and NamedSharding.__hash__ into C++.
a2880363
bchetioui [Mosaic GPU] Add conversion logic for `i4 -> f8e4m3fn`.
83c292bd
DanisNone add missing dtypes to jax.numpy.__init__.pyi
ef106037
hawkinsp Temporarily disable AVX512 in linalg_test_cpu.
4e3bf29b
Google-ML-Automation Add hermetic `nvshmem` dependencies to JAX targets.
294d86b2
Google-ML-Automation Merge pull request #29425 from DanisNone:miss-dtypes
46b9eadd
sharadmv [Pallas TPU] Support memory space constraints on pallas_call inputs.
f81f2589
Google-ML-Automation [JAX] Update the example to use jax.numpy rather than numpy.
1228053f
hawkinsp Reland the C++ safe_zip implementation.
07193467
yashk2810 Add `all_gather_invariant` to lax.
6e2977d8
Google-ML-Automation [Pallas TPU] Small fix to memory space constraints on pallas_call inp…
a9fdb768
dimitar-asenov [Mosaic GPU] Enable transpose tests in mosaic_gpu.
094b66fa
jakevdp [doc] fix some inaccuracies in jnp.bincount docs
efbedc61
jakeharmon8 Add colorama back into test-requirements
abb756d4
Google-ML-Automation Merge pull request #29441 from jakevdp:bincount-doc
109519a0
hawkinsp Use a frozenset for unconstrained_dims in sharding_constraint_p.
688c3d33
hawkinsp [jaxlib] Change Traceback to be a raw CPython class rather than a nan…
b72be575
hawkinsp Make the params of more jaxpr primitives hashable.
e04cc283
Remove unused internal optimization_barrier alias
d840447a
sharadmv fix-forward for pallas tpu memory spaces test
382b3e02
Google-ML-Automation Update XLA dependency to use revision
c86fefb5
Move jax._src.callback to its own BUILD rule
142ace2c
dimitar-asenov [Mosaic GPU] Convert all memrefs with transforms to unrealized casts …
604b6048
dimitar-asenov [Mosaic GPU] Add a Mosaic GPU op `with_transforms` for manually setti…
fc8192ca
dimitar-asenov [Mosaic GPU] Resolve different tile transforms using the largest comm…
a4f0e402
dimitar-asenov [Mosaic GPU] Use warpgroup semantics for the ragged dot example kernel.
70c90a95
jakeharmon8 Disable `too_slow` in data.draw() for test_ndindexer
9b45aacb
dimitar-asenov [Mosaic GPU] Reconcile the swizzle of the a and b operands for wgmma …
a3542f88
j-towns Add pjit_p to extend.core.primitives
7b7a5d8c
jburnim Fix return type annotation for tree_util.tree_broadcast.
e66a6dd2
dimitar-asenov [Mosaic GPU] Parametrize the `test_subview` test.
193f11db
hawkinsp Make params of several pallas primitives hashable.
670ae13d
yashk2810 Improve reshape not supported error message
92aedb28
jakevdp Internal refactor: move TPU lowering rules out of jax/_src/lax
1b1e9f71
hawkinsp Make params of assert_consumed_value_p hashable.
64c95744
Google-ML-Automation Merge pull request #29456 from j-towns:pjit-p-extend-primitives
ec637fe4
Google-ML-Automation Merge pull request #28931 from olupton:cublas-cudnn
f41fc71c
Google-ML-Automation Merge pull request #29420 from jakevdp:tpu-linalg
29eb8323
hawkinsp Make some remaining jaxpr equation params hashable.
790540c7
mwhittaker Implemented cross-host memory transfer on GPU.
b9cf0af5
hawkinsp Load CUDA libraries up front with cdll.LoadLibrary().
080294cf
Google-ML-Automation Merge pull request #29462 from hawkinsp:nvrtc
a849d216
justinjfu [Pallas] Add no_pipelining debugging option to emit_pipeline.
dfc90529
yashk2810 Replace `with_partitions` and `with_unreduced` with `.update` on Part…
8e346ddb
yashk2810 Remove `with_spec` from NamedSharding and replace with `.update`
7904b86b
hyeontaek [JAX] Fix the test names in colocated_python_test.py to following the…
f06888f9
Google-ML-Automation Update XLA dependency to use revision
e28d6ed9
hawkinsp Make mosaic_gpu equation params hashable.
d065d2ac
Google-ML-Automation Update XLA dependency to use revision
9678a764
Google-ML-Automation Update XLA dependency to use revision
b25655cf
Cjkkkk PR #28102: Add cudnn paged attention support in JAX cuDNN SDPA API
639216fa
yashk2810 Add `update_vma` and `update_weak_type` override on AbstractTMEMRef s…
53849a56
yashk2810 Add version guards to testAutoPgle
1362f7f5
berkinilbeyi Bump the libtpu check to 6/20
173574a9
superbobry Removed unused `PyTreeDef::MakeFromNodeDataAndChildren` and its Pytho…
0e7c96a5
basioli-k [jax] Increase absolute test tolerance for lax_control_flow test
dbde6c44
hawkinsp Fix some more instances of unhashable jaxpr equation arguments.
49e52c08
gnecula [export] Add back-compat test for tridiagonal solve on GPU
e5af0880
mwhittaker Set heartbeat_timeout argument and flag.
a78c6a70
kanglant Install SciPy from its source (head) to test against Python 3.14.0b1
09d903fc
jakevdp [doc] add missing axis_types documentation
7aec14f1
dfm Remove legacy CPU custom calls.
8865ee66
Google-ML-Automation Merge pull request #29497 from jakevdp:make-mesh-doc
35ae958b
hawkinsp Fix a missing bounds check in traceback code.
cb1cc379
vfdev-5 Removed fixed suppressions
97e580c8
tlongeri [Mosaic:TPU][NFC] Delete unused variable
748b39f4
hyeontaek [JAX] Relax the return type of `colocated_python` decorator
b6225124
zacmustin Add custom-call ops to roofline.
2dfaceda
sannidhyachauhan Removing Tensorflow references from the document.
0f5cdba0
sannidhyachauhan Add test for programmatic tracing with options.
acf99e5e
Google-ML-Automation Update XLA dependency to use revision
4be1402b
ZixuanJiang Pass through the `use_shardy_partitioner` with `jax.config.jax_use_sh…
93453f54
WindQAQ [Mosaic] Use BF16 ops for math::PowF on TPUv6+.
91651f80
jburnim Update Pallas debugging doc with TPU interpret mode + dynamic race de…
077f2b66
Google-ML-Automation Prefer binaries in NVIDIA `nvcc` wheel over system CUDA installation …
124c723a
cky9301 Add an API to overwrite the current execution_stream_id and respect i…
332aa354
jakevdp jax.experimental.enable_x64: add warning to docstring
f22896ac
Google-ML-Automation [Pallas TPU] Add flag to enable using registers to keep track of slot…
0fd08213
rosiezou add psend and precv to jax/lax/parallel
784be1f9
superbobry [pallas] `AbstractMemoryRef` now implements all functional update met…
c2cc9f9c
Google-ML-Automation Merge pull request #29504 from vfdev-5:tsan-ft-removed-fixed-suppression
3d37b0d7
Google-ML-Automation Merge pull request #29516 from jakevdp:enable-x64-warning
353e7fac
Google-ML-Automation Merge pull request #29410 from DanisNone:nn-type
dc9ef614
justinjfu [Pallas][Mosaic GPU] Enable collective MMA from TMEM.
02688e18
Google-ML-Automation Update XLA dependency to use revision
e4de90e6
yashk2810 Prepare for JAX release 0.6.2
8f81490a
Google-ML-Automation Merge pull request #29135 from rosiezou:main
755bb676
Rollback https://github.com/jax-ml/jax/pull/29410 due to downstream p…
7dd13d70
zacmustin Add `cum{logsumexp, min, max, prod, sum}` to JAX roofline.
c944c656
hyeontaek [JAX] Remove sleeping from colocated Python execution tests
19f34a06
yashk2810 Postrelease (0.6.2) changes
feab6f4b
Google-ML-Automation Merge pull request #29528 from jax-ml:postrelease
8b88cc8e
zacmustin Add `gather` to `roofline`.
34d88cc5
yashk2810 jaxlib_extension_version == 355 after 0.6.2 release. So remove the co…
1d077471
tlongeri [Mosaic:TPU] Byte-granularity dynamic gathers
366a7dfe
superbobry [mosaic] Added a `k` prefix to `TPU_MemorySpace` members
7c432e95
apaszke [Mosaic GPU] Rework the CUDA_ROOT detection once again
2ec99812
apaszke [Mosaic GPU] Add support for s8 matmuls on Blackwell
b6575e19
hawkinsp Drop Python 3.10 support.
1cd076fd
bchetioui [Mosaic GPU] Implement canonicalization for `TiledLayout`s.
f562884c
Google-ML-Automation Merge pull request #29543 from hawkinsp:py310
cad4ba7e
yashk2810 Remove `_allow_deprecated_jit_signature` now that 0.6.2 is out and ne…
5e2afe6e
hawkinsp Add a cache around abstract_eval rules.
9cf81e4d
jakevdp Finalize a number of deprecations for JAX v0.7.0
36256960
Google-ML-Automation Merge pull request #29549 from jakevdp:finalize-deps
a7fe11d5
Google-ML-Automation [Mosaic TPU] Make the backward-compatibility libtpu condition stricter
ffabd3ea
hawkinsp Remove some dangling references from the docs.
4f437a3e
justinjfu [Mosaic GPU] Fix minor error in matmul test.
364f0049
Google-ML-Automation Merge pull request #29558 from hawkinsp:docs
98591573
hawkinsp Run pyupgrade --py311-plus.
59034e8e
yashk2810 Remove PositionalSharding from JAX now that 0.6.2 is out and next rel…
4aa1db4f
kanglant Skip pytest for Python 3.14 during the JAX release process
4f871b0e
jburnim Fix rare error with Literal in DynamicJaxprTracer.full_lower.
7c9613a6
Move jax._src.lax to its own BUILD rule
511bf2fe
dubstack Fix bugs in the double_buffered_pipeline example
4c54c022
Google-ML-Automation Merge pull request #29550 from hawkinsp:py311
09121436
kanglant Update jax requirements lock files after 0.6.2 release
cb2315a8
Google-ML-Automation Merge pull request #29554 from dubstack:patch-1
8cbd915b
jakevdp [doc] fix build error
6567192a
Google-ML-Automation Merge pull request #29563 from jakevdp:fix-doc
8337fe5b
justinjfu [Pallas][Mosaic GPU] Add GPU pipelining docs
5e290ddd
Google-ML-Automation Merge pull request #29560 from kanglant:update_lock_files
d3f08713
jakevdp Add wrap_negative_indices paramter to jnp.ndarray.at[]
a47ae573
subhankarshah [JAX] Skip failing tpu tests until June 30th.
9d1b01e0
Cjkkkk add cudnn sdpa mla support
d1883667
Google-ML-Automation Merge pull request #29434 from jakevdp:normalize-indices
00852217
Google-ML-Automation Merge pull request #28135 from justinjfu:gpu_pipe_docs
4d2b14a4
Google-ML-Automation Merge pull request #28872 from Cjkkkk:jax_cudnn_sdpa_mla
4efa56d2
ZixuanJiang Pass shardy option through jax config.
1e4a0f76
WillFroom Reenable AVX512 after LLVM fix upstream.
0b54a1e8
bchetioui [Mosaic GPU] Delete dead code in `layout_inference.py`.
e55f55fc
dimitar-asenov Create a test suite for Pallas mosaic GPU tests.
b99d004c
apaszke [Pallas:MGPU] Add docs for pl.core_map and plgpu.kernel
f99d2b49
bchetioui [Mosaic GPU] Change layout inference tests to rely on explicit `layou…
84066b77
bchetioui [Mosaic GPU][NFC] Add `checkInLayouts` and `checkOutLayouts` utils to…
e818940d
yashk2810 Remove `Layout`, `.layout`, `.input_layouts` and `.output_layouts` an…
bfc07e2f
bchetioui [Pallas/Mosaic GPU] Propagate transforms on the accumulator in `tcgen…
3fdc97b8
superbobry [pallas:mosaic] Fixed a typo in the distributed tutorial
346ce85d
superbobry [mosaic] `MemRef{Slice,Squeeze}` verifiers now support strided layouts
f3370cb5
Google-ML-Automation Add some tracemes in py_array to make slow device put debugging easier.
d46202e4
junwhanahn Add traceme to `PythonRefManager::CollectGarbage`
71ea45bb
rocm-repo-management-api-2 rocm-repo-management-api-2 requested a review 188 days ago
rocm-repo-management-api-2[bot] rocm-repo-management-api-2[bot] enabled auto-merge (rebase) 188 days ago

Login to write a write a comment.

Login via GitHub

Reviewers
No reviews
Assignees
No one assigned
Labels
Milestone