jax
CI: 06/20/25 upstream sync
#477
Open
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Overview
Commits
2162
Changes
View On
GitHub
CI: 06/20/25 upstream sync
#477
rocm-repo-management-api-2
wants to merge 2162 commits into
main
from
ci-upstream-sync-220_1
Add fast-path for non-concrete Tracers in is_constant_dim to lower tr…
795bbfc6
Reverts 5c33588b30edbae51d5b63b0bd7cc8d9058d7ccb
c1bb095c
[Mosaic GPU] Extract the type-related logic out of `reinterpret_smem_…
64ba9bc4
Port PartitionSpec to C++.
8478fe6e
Fix segfault if None is passed to PartitionSpec.__eq__.
e4c8da1e
[Pallas][Mosaic GPU] Expose partitioned collective loads to copy_gmem…
09111c8e
Optimize jaxpr equation pretty-printing.
7dd721a8
[JAX] Allow registering callbacks to be called when backends are cleared
58a19372
PRNGKeyArray doesn't have a format field so assign the layout to be N…
ec61161d
Reverts c1bb095c5ce5b0286dc5052abf3b597b6f23cea5
c9289ae1
Add committed property to MutableArray
567d61e2
[Pallas TPU] Add custom_vjp_call lowering rule
3248d55a
Small speedups to pretty-printing.
051386f0
Set `in_sharding` to UNSPECIFIED if a mutableArray is uncommitted whe…
0d1edcce
[Pallas][Mosaic GPU] Add collective (CTA-pair) MMAs to blackwell matm…
6ae26141
[mutable-arrays] make partial_eval_jaxpr forward input-residuals
3699aa74
Merge pull request #29311 from mattjj:mutable-array-custom-vjp
0e9c0e4e
Clarify argument order for lax.associative_scan when reverse=True.
58faffd0
Update XLA dependency to use revision
1cb18ec2
fix vestigial change that caused breakage
cb2b2179
Update XLA dependency to use revision
0d1b1ef7
Port pretty-printer to C++.
c2a56909
Fix typo in error message.
d69086db
Minor fix to doc for random.orthogonal.
880dd131
[Pallas][Easy] Terser printing of GridMapping unless debug is set.
d8317b53
Don't trigger debug_infs in ndtri unless an inf is returned.
2281455a
[jax2tf] Refine the disabling of jax2tf_test, for versions <= 2.19.1
a6d95ee3
Merge pull request #29321 from carlosgmartin:fix_random_orthogonal_doc
a572a997
Merge pull request #29313 from carlosgmartin:document_associative_sca…
74e362eb
Move jax/_src/api.py and associated files to their own BUILD rule
6d729fe7
Reverts b7833e94c1940ed475dae1f5e83e2a984cda5cea
92d5fe88
jax.nn.standardize: improve documentation
88de1e61
Merge pull request #29330 from dfm:ndtri-infs
52704b77
Merge pull request #29332 from jakevdp:standardize-doc
3157265e
Move jax/_src/custom_batching.py to its own build rule
42ea2ac3
Move jax/_src/earray.py to its own build rule
89e0c7ec
[array-api] pin array-api-tests to 2025.05.23
9651e608
Fix spelling error in the name of the input variable.
7c35300e
[Mosaic GPU] Error when causal masking is used on cuda versions known…
39de7159
Move jax/_src/ffi.py to its own build rule
9a591624
Move jax/_src/custom_partitioning.py to its own build rule
bfd0744f
Merge pull request #29344 from jakevdp:array-api-test-pin
a7c67e5a
Move jax/_src/buffer_callback.py to its own build rule
df50cd7e
[mutable-arrays] make custom_api_test.py pass with JAX_MUTABLE_ARRAY_…
31e99987
Merge pull request #29352 from mattjj:mutable-array-custom-vjp-fixes
5f2ce853
Move jax/_src/shard_alike.py to its own build rule
a58b27c6
[Pallas TPU][NFC] Use register to track buffer slots in pipeline loop
14aaa45f
add reference to pr-checklist
34cee968
Remove forward_compat check for alpha as it is past the support date.
b22be866
Update JAX test to not rely on ToString and instead check the Device …
28c31b8c
[pallas] Fix shard_map + Megacore in TPU interpret mode.
65c14a2c
Don't revisit shared subjaxprs in jaxpr_util.pprof_equation_profile.
c126ee3d
Automated Code Change
cc971e35
Update XLA dependency to use revision
b59a97f0
Merge pull request #28916 from gnecula:tf_version
297ee7c9
Improve batching for lax.platform_dependent
31ef2cf2
Move jax/_src/public_test_util.py to its own build rule
be23dcf2
[Pallas/Mosaic GPU] Fix the abstract eval rule for `load_p` in the pr…
f053dfed
Merge pull request #29257 from benquike:main
fdf7a2cb
Pass source_info to custom_staging_rules and into jaxpr inlining.
4846ed2f
[mutable-arrays] upgrade scan to work with partial_eval_jaxpr_fwd
5d64c395
Merge pull request #29353 from mattjj:mutable-array-custom-vjp-scan2
56f3293c
Merge pull request #29358 from jenriver:documentation
14a9b590
[jax2tf] fix jax2tf sharding tests for shardy
b999fab1
[Mosaic GPU] Fix test after a previous PR changed the config params.
9e0472c1
Skip NumPy's `isClose` test for NumPy 2.3.0
c0541354
fix for a downstream breakage from #29353
03b01526
* Add support for output and input memory space colors in tpu custom …
89fca529
[JAX] Add `vma` to `ShapeDtypeStruct` constructor arguments.
bdd635a6
Add is_leaf_with_path predicate.
160e59f1
Pallas documentation fixes.
423aafe7
Rollback of #29353 due to downstream failures
ed03f383
Initial commit to make unreduced + AD work.
22ceb687
[JAX] Move the fallback of `colocated_cpu_devices` logic from the col…
d9e0244c
Add basic mutable array tests with AOT
fd6d90a7
Save a jaxpr equation in pl.cdiv if the rhs is an int.
f211c6b9
Update XLA dependency to use revision
cd4a0c6e
Ensure that all attributes are restored after pickling in `NamedShard…
02cdc7ba
[Mosaic GPU] Use _slice_smem also for barriers.
9da048ec
[pallas:mosaic] A few more primitives now have lowerings for all kern…
de82f9f4
[Mosaic GPU] Remove unneeded code.
827b855b
Propagate source_info in more places:
0edfb44f
Ensure that memory_kind is restored after pickling in SingleDeviceSha…
1cd49920
Don't recompute np.iinfo in _scalar_type_to_dtype.
cfebc498
Set explicit dot precision in the sparse solver test.
5f0e7e47
Migrated to mypy 1.16.0
225f0c62
Delete instantiate_const_abstracted.
f7f2ce5c
Merge pull request #29350 from jburnim:jburnim_interpret_shard_map_pl…
f519ad7c
add doc comment to vma in ShapedArray
6c9bcfc6
Do not call update_weak_type on the result of get_aval().
b87ea1c5
Move jax/_src/export to its own build rule
772cde81
Merge pull request #29294 from olupton:fix-scaled-matmul-stablehlo-test
703cf91b
Move materialization of NDIndexer out of draw()
2618d9b1
[JAX] Extend `colocated_cpu_devices` to accept `Mesh` besides devices
5c2a3201
extend pallas paged_attention with kv scales
9f8be25f
Expose local/global `ExchangeTopologies` timeouts for PJRT CPU client.
e61ee7e2
Merge pull request #29405 from superbobry:union-fix
5b49b289
Merge pull request #29354 from rdyro:paged_attention_with_scales
b40d79c2
[XProf] Change tensorboard-plugin-profile to new xprof package
97ab9e0b
Add alternative location of `CUDA_ROOT` for Bazel build/tests with he…
004b5480
Merge pull request #29129 from Matt-Hurd:rename_to_xprof
d00b1cad
Add a pytype disable around zstandard.
3ef4db4d
Add execution to unreduced tests now that it works end-to-end
cc976cb0
Add nightly linux jax wheel tests for python 3.14.0b1
45e61d8c
Fix GPU quantized paged attention tests for < sm89
e7d252dc
Merge pull request #29423 from rdyro:gpu_paged_fix
2c467d6a
Merge pull request #29362 from gnecula:platform_index_linearize
1886a9e1
Update XLA dependency to use revision
4ad4d4a8
add jax.nn module type hints (__init__.pyi)
87ce7d1b
Move NamedSharding.__eq__ and NamedSharding.__hash__ into C++.
a2880363
[Mosaic GPU] Add conversion logic for `i4 -> f8e4m3fn`.
83c292bd
add missing dtypes to jax.numpy.__init__.pyi
ef106037
Temporarily disable AVX512 in linalg_test_cpu.
4e3bf29b
Add hermetic `nvshmem` dependencies to JAX targets.
294d86b2
Merge pull request #29425 from DanisNone:miss-dtypes
46b9eadd
[Pallas TPU] Support memory space constraints on pallas_call inputs.
f81f2589
[JAX] Update the example to use jax.numpy rather than numpy.
1228053f
Reland the C++ safe_zip implementation.
07193467
Add `all_gather_invariant` to lax.
6e2977d8
[Pallas TPU] Small fix to memory space constraints on pallas_call inp…
a9fdb768
[Mosaic GPU] Enable transpose tests in mosaic_gpu.
094b66fa
[doc] fix some inaccuracies in jnp.bincount docs
efbedc61
Add colorama back into test-requirements
abb756d4
Merge pull request #29441 from jakevdp:bincount-doc
109519a0
Use a frozenset for unconstrained_dims in sharding_constraint_p.
688c3d33
[jaxlib] Change Traceback to be a raw CPython class rather than a nan…
b72be575
Make the params of more jaxpr primitives hashable.
e04cc283
Remove unused internal optimization_barrier alias
d840447a
fix-forward for pallas tpu memory spaces test
382b3e02
Update XLA dependency to use revision
c86fefb5
Move jax._src.callback to its own BUILD rule
142ace2c
[Mosaic GPU] Convert all memrefs with transforms to unrealized casts …
604b6048
[Mosaic GPU] Add a Mosaic GPU op `with_transforms` for manually setti…
fc8192ca
[Mosaic GPU] Resolve different tile transforms using the largest comm…
a4f0e402
[Mosaic GPU] Use warpgroup semantics for the ragged dot example kernel.
70c90a95
Disable `too_slow` in data.draw() for test_ndindexer
9b45aacb
[Mosaic GPU] Reconcile the swizzle of the a and b operands for wgmma …
a3542f88
Add pjit_p to extend.core.primitives
7b7a5d8c
Fix return type annotation for tree_util.tree_broadcast.
e66a6dd2
[Mosaic GPU] Parametrize the `test_subview` test.
193f11db
Make params of several pallas primitives hashable.
670ae13d
Improve reshape not supported error message
92aedb28
Internal refactor: move TPU lowering rules out of jax/_src/lax
1b1e9f71
Make params of assert_consumed_value_p hashable.
64c95744
Merge pull request #29456 from j-towns:pjit-p-extend-primitives
ec637fe4
Merge pull request #28931 from olupton:cublas-cudnn
f41fc71c
Merge pull request #29420 from jakevdp:tpu-linalg
29eb8323
Make some remaining jaxpr equation params hashable.
790540c7
Implemented cross-host memory transfer on GPU.
b9cf0af5
Load CUDA libraries up front with cdll.LoadLibrary().
080294cf
Merge pull request #29462 from hawkinsp:nvrtc
a849d216
[Pallas] Add no_pipelining debugging option to emit_pipeline.
dfc90529
Replace `with_partitions` and `with_unreduced` with `.update` on Part…
8e346ddb
Remove `with_spec` from NamedSharding and replace with `.update`
7904b86b
[JAX] Fix the test names in colocated_python_test.py to following the…
f06888f9
Update XLA dependency to use revision
e28d6ed9
Make mosaic_gpu equation params hashable.
d065d2ac
Update XLA dependency to use revision
9678a764
Update XLA dependency to use revision
b25655cf
PR #28102: Add cudnn paged attention support in JAX cuDNN SDPA API
639216fa
Add `update_vma` and `update_weak_type` override on AbstractTMEMRef s…
53849a56
Add version guards to testAutoPgle
1362f7f5
Bump the libtpu check to 6/20
173574a9
Removed unused `PyTreeDef::MakeFromNodeDataAndChildren` and its Pytho…
0e7c96a5
[jax] Increase absolute test tolerance for lax_control_flow test
dbde6c44
Fix some more instances of unhashable jaxpr equation arguments.
49e52c08
[export] Add back-compat test for tridiagonal solve on GPU
e5af0880
Set heartbeat_timeout argument and flag.
a78c6a70
Install SciPy from its source (head) to test against Python 3.14.0b1
09d903fc
[doc] add missing axis_types documentation
7aec14f1
Remove legacy CPU custom calls.
8865ee66
Merge pull request #29497 from jakevdp:make-mesh-doc
35ae958b
Fix a missing bounds check in traceback code.
cb1cc379
Removed fixed suppressions
97e580c8
[Mosaic:TPU][NFC] Delete unused variable
748b39f4
[JAX] Relax the return type of `colocated_python` decorator
b6225124
Add custom-call ops to roofline.
2dfaceda
Removing Tensorflow references from the document.
0f5cdba0
Add test for programmatic tracing with options.
acf99e5e
Update XLA dependency to use revision
4be1402b
Pass through the `use_shardy_partitioner` with `jax.config.jax_use_sh…
93453f54
[Mosaic] Use BF16 ops for math::PowF on TPUv6+.
91651f80
Update Pallas debugging doc with TPU interpret mode + dynamic race de…
077f2b66
Prefer binaries in NVIDIA `nvcc` wheel over system CUDA installation …
124c723a
Add an API to overwrite the current execution_stream_id and respect i…
332aa354
jax.experimental.enable_x64: add warning to docstring
f22896ac
[Pallas TPU] Add flag to enable using registers to keep track of slot…
0fd08213
add psend and precv to jax/lax/parallel
784be1f9
[pallas] `AbstractMemoryRef` now implements all functional update met…
c2cc9f9c
Merge pull request #29504 from vfdev-5:tsan-ft-removed-fixed-suppression
3d37b0d7
Merge pull request #29516 from jakevdp:enable-x64-warning
353e7fac
Merge pull request #29410 from DanisNone:nn-type
dc9ef614
[Pallas][Mosaic GPU] Enable collective MMA from TMEM.
02688e18
Update XLA dependency to use revision
e4de90e6
Prepare for JAX release 0.6.2
8f81490a
Merge pull request #29135 from rosiezou:main
755bb676
Rollback https://github.com/jax-ml/jax/pull/29410 due to downstream p…
7dd13d70
Add `cum{logsumexp, min, max, prod, sum}` to JAX roofline.
c944c656
[JAX] Remove sleeping from colocated Python execution tests
19f34a06
Postrelease (0.6.2) changes
feab6f4b
Merge pull request #29528 from jax-ml:postrelease
8b88cc8e
Add `gather` to `roofline`.
34d88cc5
jaxlib_extension_version == 355 after 0.6.2 release. So remove the co…
1d077471
[Mosaic:TPU] Byte-granularity dynamic gathers
366a7dfe
[mosaic] Added a `k` prefix to `TPU_MemorySpace` members
7c432e95
[Mosaic GPU] Rework the CUDA_ROOT detection once again
2ec99812
[Mosaic GPU] Add support for s8 matmuls on Blackwell
b6575e19
Drop Python 3.10 support.
1cd076fd
[Mosaic GPU] Implement canonicalization for `TiledLayout`s.
f562884c
Merge pull request #29543 from hawkinsp:py310
cad4ba7e
Remove `_allow_deprecated_jit_signature` now that 0.6.2 is out and ne…
5e2afe6e
Add a cache around abstract_eval rules.
9cf81e4d
Finalize a number of deprecations for JAX v0.7.0
36256960
Merge pull request #29549 from jakevdp:finalize-deps
a7fe11d5
[Mosaic TPU] Make the backward-compatibility libtpu condition stricter
ffabd3ea
Remove some dangling references from the docs.
4f437a3e
[Mosaic GPU] Fix minor error in matmul test.
364f0049
Merge pull request #29558 from hawkinsp:docs
98591573
Run pyupgrade --py311-plus.
59034e8e
Remove PositionalSharding from JAX now that 0.6.2 is out and next rel…
4aa1db4f
Skip pytest for Python 3.14 during the JAX release process
4f871b0e
Fix rare error with Literal in DynamicJaxprTracer.full_lower.
7c9613a6
Move jax._src.lax to its own BUILD rule
511bf2fe
Fix bugs in the double_buffered_pipeline example
4c54c022
Merge pull request #29550 from hawkinsp:py311
09121436
Update jax requirements lock files after 0.6.2 release
cb2315a8
Merge pull request #29554 from dubstack:patch-1
8cbd915b
[doc] fix build error
6567192a
Merge pull request #29563 from jakevdp:fix-doc
8337fe5b
[Pallas][Mosaic GPU] Add GPU pipelining docs
5e290ddd
Merge pull request #29560 from kanglant:update_lock_files
d3f08713
Add wrap_negative_indices paramter to jnp.ndarray.at[]
a47ae573
[JAX] Skip failing tpu tests until June 30th.
9d1b01e0
add cudnn sdpa mla support
d1883667
Merge pull request #29434 from jakevdp:normalize-indices
00852217
Merge pull request #28135 from justinjfu:gpu_pipe_docs
4d2b14a4
Merge pull request #28872 from Cjkkkk:jax_cudnn_sdpa_mla
4efa56d2
Pass shardy option through jax config.
1e4a0f76
Reenable AVX512 after LLVM fix upstream.
0b54a1e8
[Mosaic GPU] Delete dead code in `layout_inference.py`.
e55f55fc
Create a test suite for Pallas mosaic GPU tests.
b99d004c
[Pallas:MGPU] Add docs for pl.core_map and plgpu.kernel
f99d2b49
[Mosaic GPU] Change layout inference tests to rely on explicit `layou…
84066b77
[Mosaic GPU][NFC] Add `checkInLayouts` and `checkOutLayouts` utils to…
e818940d
Remove `Layout`, `.layout`, `.input_layouts` and `.output_layouts` an…
bfc07e2f
[Pallas/Mosaic GPU] Propagate transforms on the accumulator in `tcgen…
3fdc97b8
[pallas:mosaic] Fixed a typo in the distributed tutorial
346ce85d
[mosaic] `MemRef{Slice,Squeeze}` verifiers now support strided layouts
f3370cb5
Add some tracemes in py_array to make slow device put debugging easier.
d46202e4
Add traceme to `PythonRefManager::CollectGarbage`
71ea45bb
rocm-repo-management-api-2
requested a review
188 days ago
rocm-repo-management-api-2[bot]
enabled auto-merge (rebase)
188 days ago
Login to write a write a comment.
Login via GitHub
Reviewers
No reviews
Assignees
No one assigned
Labels
None yet
Milestone
No milestone
Login to write a write a comment.
Login via GitHub