jax
CI: 04/30/25 upstream sync
#392
Merged
charleshofer merged 900 commits into rocm-main from ci-upstream-sync-183_1
rocm-repo-management-api-2 requested a review 240 days ago
rocm-repo-management-api-2[bot] enabled auto-merge (rebase) 240 days ago
charleshofer force pushed from 806190d0 to 35fa98b3 240 days ago
disabled auto-merge 240 days ago (Manually disabled by user)
charleshofer enabled auto-merge (rebase) 240 days ago
charleshofer approved these changes on 2025-04-30
disabled auto-merge 240 days ago (Rebase failed)
JAX release v0.6.0
9e9a3f03
Update version numbers after release.
f7e3fe62
Remove _manual_axes from NamedSharding since we can now track the man…
46fb4001
Remove jax_varying_axes_in_types config and `rewrite` from `shard_map_p`
85eb58a1
Use llvm::cast/dyn_cast/isa since alternatives are deprecated in http…
4254524f
Added `lax.axis_size` and switched all existing usage of `psum(1, ...…
f5101ef7
Do not use `-> ...`
2bf0ca19
Automated Code Change
fbfaa0d1
Suppress race between split_keys_entry_added and dict_dict_merge in 3…
b55a4e94
Update XLA dependency to use revision
1770096e
Remove shard_map rewrite machinery since we have `vma` in types by de…
1fec3941
Make `ShardMapTracer` track `vma` instead of `rep`.
0ca67825
Exclude running tests against the oldest supported libtpu for releases
7f709323
Use _make_lengths_same for explicit mode too.
4f5ed5a2
[Pallas] Fix potential race condition in Pallas TPU docs
43287e37
Enable cross-compilation builds of the wheels.
8d8385ad
Remove unused jax_spmd_mode flag.
c9024ae0
[pallas:mosaic] Replaced `device_type=` with `kernel_type` in `TPUCom…
7774297a
Make Clang use manylinux C++ standard library
9fbbffb3
Remove roofline subtests so they work with pytest.
06251b38
Don't recompute source_info for each tracer during staging.
e6ac43a0
Return empty `wheel sources map` in case if `wheel_sources` is `None`.
e2e2bd1e
Remove code to support jaxlib < v0.6.
d100df27
Enter into auto mode for `.at[...].get(...)` a bit earlier so that al…
ffae4a29
Rename `with_user_mesh` to `with_explicit_mesh`
b93f0282
[Pallas] Generalize BlockSpec to support different indexing mode for …
9558ca0d
[Mosaic TPU] Enable non-sublane-aligned 2D int8 load/stores
cb04ff56
Update XLA dependency to use revision
6a811093
Remove unused type forwarding declarations for jax.lib.xla_client.{Sh…
98259c61
Reverts c2ba1790417ca206a4d88b25aef4d5ae510dd717
f66d71e4
Allow the CPU collective implementation to be overridden to None.
5f57c666
[JAX] Remove jax.lib.xla_client.mlir_api_version and its uses.
59b155ac
Don't use the deprecated jax.dlpack.to_dlpack in tests.
ea3557c9
Drop into `Auto` mode for `.at[...].set(...)` but instead of taking a…
4fb72216
Mark CliDebuggerTest as a thread-unsafe test.
1f865a0b
Ensure outputs are tracers when inlining jit.
a4d2e28e
#sdy Do not close any partially sharded dimensions if using auto axes…
3d2c2dd1
Handle `sharding` param in convert_element_type's batching rule prope…
30e3809d
Makes Effort_02 the default value for memory_fitting_level.
67c4e75c
Refactor array serialization into separate JAX and tensorstore logic
b8c3542d
Add jax.fwd_and_bwd
1144cb03
Update changelog to add information about breaking change with respec…
7934a4ba
Make `Shape::add_dimensions()` validate arguments by default.
0b8b0e2f
Update XLA dependency to use revision
76124057
fix BUILD file syntax error
3d918f20
[Pallas/Fuser] Add support for pl.Element in fuser BlockSpec
d587b76a
Update XLA dependency to use revision
29d743e6
Add support for TPU7x in Mosaic.
87f086e5
Update XLA dependency to use revision
ca9f3d93
fix: typo
a36a8f71
Fixed cached mlir module mutation issue in export
410d56e5
jax.random.bernoulli: add mode='high' for improved sampling for small p
c187b5cc
Simplify (and potentially accelerate) fftfreq(), as a single vectoriz…
cf56cb80
Allow all reshapes if the operand is fully replicated
4a4a0f05
[docs] fix typo in doc link
ab2516db
implement _replace_with for PRNGKeyArray
844187ff
Switch JAX to the new `ProgramShape::AddParameter()` API.
1193da06
[ragged-paged-attn] Add auto-tuned table that is being used for vLLM
40eb37f9
[Pallas][jax] Better error message for unexpected types in standard a…
f24595cd
[JAX] Use `xla::ifrt::Client::MakeArraysFromHostBufferShards()` in Ar…
7688ee7e
[Mosaic GPU] Align on device profiler array in smem to 8 bytes
374a5edd
#sdy Add more debug info when there is a mesh mismatch in JAX export
42811dcb
[mosaic_gpu] Fixed the export code path in `_mosaic_gpu_lowering_rule`
b877daaa
Update XLA dependency to use revision
8c675bca
jnp.frexp: add custom JVP rule for proper derivatives.
fc3dba80
Add support for generating freethreaded pip lockfiles.
18a5f05f
[Pallas] Support the new TPU interpret mode with pl.core_map.
af4cd548
Fix ensure_compile_time_eval context after stackless.
81dd0b47
BUG: return one covariance matrix per `rhs` in `polyfit`
1cfa3894
typing: improve annotations for scatter implementations
8347142b
[Pallas][Mosaic TPU] Improve error message for 1D block specs when th…
445fb63b
[CI] Add tpu v6e-8 to nightly test and release test.
6809c690
[Pallas] Update TPU pipelining docs
aca090d7
[Pallas Fuser] Allow multiple BlockSpec inputs to select_n push rule …
56e12fd6
Fix handling of SymbolicZero output when batching custom_jvp.
26691f39
[Mosaic GPU] Do not create a barrier with `arrival_count == 0` in the…
3a5504e5
[Pallas Fuser] Change physicalize to resolve_fusion_dtypes
755265d5
[Pallas] Propagate Jaxpr effects through pl.fusible
f060fcc6
Set core_index to default for tpu_pallas_async_test
ebb21ee5
jnp.ldexp: avoid overflow for large exponent
541c1169
Update tsan requirements patch after lockfile update.
0dccb0f5
Initial commit of non-experimental `shard_map`.
d3e19fea
[ragged-paged-attn] Autotune set up and increase page size to avoid S…
6321af49
DOC: link to ai-stack tutorials from JAX's front page
e0295f7d
Skip unary_ops_accuracy test for TPU version 7 and above.
36a6c12b
[Mosaic:TPU][Relayout] Row shifts for packed types and non-native til…
06390001
Account for versioned clang binaries
5253e99f
Expose `jax.shard_map` as the public non-experimental API and move sh…
330f414a
[Pallas] Make Pallas PRNG more robust by improving the Mosaic seed br…
d35c72df
Make `set_mesh` thread local so that it behaves exactly like `use_mesh`
fbbbb7b6
[pallas:mosaic_gpu] Added an export backward compatibility test for a…
069409c3
Replace reference to jax.readthedocs.io with docs.jax.dev in jax._src…
d62d667b
Include MGPU in grid/blockspec tutorial + use proper note/warning for…
26a5bce9
Update XLA dependency to use revision
fa4af688
Update install docs for AMDGPU
f58f0753
Inline literals while tracing instead of in a separate pass.
58ea6785
Handle DShapedArray in caller.
830ffaa7
Fix typo "dataclasss".
964e8292
Fix some typos.
7880c58d
Fix mypy errors related to jaxlib.
4e950d9c
Cleanup: don't redefine softmax in jax.random
4a5375f6
[Pallas/Mosaic GPU] Allow explicit `smem` aliasing.
acedda03
Add a pvary on the `ones` we create in grad because the `ans` in the …
4853aa47
Add a document for activation and parameter offloading
ab5f6706
TSAN FT CI, install cython from nightly wheels
2b67b740
Move contents of jaxlib/xla into jaxlib/
2a245036
Reverts 49e25c6167806ba90efe8370fb04db3f4966437c
9923ba80
Automatically initialize distributed runs in K8s indexed jobs
4bb2e2a5
Skip host_offloading notebook from running
c08fd51f
jax.numpy: make standard input utilities respect __jax_array__
466d8c49
internal change
b554a65e
Set is_custom field for custom_partitioning_sharding_rule so that rul…
8f82928b
Remove config `_check_rep` to `_check_vma` and the kwarg in shard_map…
8f87dc19
[pallas] `pl.pallas_call` no longer allows `compiler_params=` to be a…
811342ca
Migrate jax.experimental.shard_map to jax.shard_map in internal JAX a…
57c5175e
[pallas] Added a note on the recent `compiler_params=` change to the …
44400ff5
[pallas] Fixed the type of `MemoryRef.dtype`
bba6c56f
[Mosaic GPU] Introduce a dedicated `DialectBarrierRef` and handle war…
4df5a5e8
Update XLA dependency to use revision
2a535e45
Update FFI tutorial to new shard_map.
a29196af
Change pallas/distributed.{ipynb,md} to use jax.shard_map
f25bed71
Fix shardy sharding rule for SVD decomposition.
fc62fa19
#sdy avoid an extra call to `mlir.module_to_bytecode` and use `mlir::…
5f08cd04
Refactor LocalMask to inherit from _ComputableMask
fa56b7b3
[pallas:mosaic] Fixed a bug in `pl.debug_print` lowering
c1630e1d
Add `__str__` to `UnshapedArray` so that whenever we `print(aval)`, w…
6cde5323
Ran pyupgrade --py310-plus on .pyi files in jaxlib
9e937fc0
Make inspect_array_sharding inside shard_map work with `check_vma=Tru…
d1523105
Cleanup: remove superfluous jax.numpy utility
34d8d4ff
Remove PositionalSharding from JAX from almost all places except for …
718c103a
[pallas] Removed the deprecated `jax.experimental.pallas.gpu` module
531c5d4c
[Mosaic GPU] Adding a deterministic backwards pass to the pallas MGPU…
a0dfbcd2
Re-land: "Don't recompute source_info for each tracer during staging".
d1cb3016
Update README.md to remove jax.pmap
ce4c923c
use :check: and :x: emojis in readme table
5a1be75b
Added new multi-controller JAX guide.
e546ad98
Update the sharding table to be the same everywhere
bc601c40
fix: in jax.asarray check default_device instead of default_backend
bf5ec8cd
Set __module__ on NamedSharding, SingleDeviceSharding and PmapShardin…
f4debd71
Add mosaic tests to optional B200 GPU presubmit.
76dc1e7a
[readme] include teaser example of shardings
b4ea9e91
Move sharding table above the code example
28716264
Updated TSAN suppresssions files to get traceback of crashed process
711e26c1
fix: make build.py use /usr/bin/env python3 as shebang
a10ad049
Fix typo in Bazel command for Mosaic tests.
a186d792
[Pallas/Fuser] Bugfix for broadcasting, lax.slice_p, and lax.dynamic_…
3a451fff
Update `debug_info` tests for `use_direct_linearize`.
6ce2bb55
[scan] when a carry is read-only, move it to be a const
f7611ef5
Fix shard_map's direct linearize after vma has been turned on
aea756d7
[pallas] Use `*MemorySpace` aliases
7e7e3e5a
[Pallas TPU] Introduce a BoundedSlice block shape type
3b155efa
[JAX:Sparse] Add BCSR benchmarks to sparse_benchmarks.py.
537b8754
[Mosaic GPU] Fix up the Blackwell code to match changes in the CUDA r…
582ab0c1
[Mosaic TPU] Allow indexing refs with narrow integers
923364ea
[Mosaic TPU] Allow simultaneous column and row shifts
2e5fa04d
Update XLA dependency to use revision
526ccbd9
[Mosaic TPU] Relax test restrictions after improving DMA support for …
5baa2a5c
#jax #sdy add a Shardy config to mock_gpu_topology_test.
953379e3
[Pallas] Remove leftover debug=True in Pallas tests
94bf5d0c
[JAX] Remove xla_client_test.
7dd98cb4
[CI] Correct the optional GPU presubmit to work properly with workflo…
24543a37
#jax #sdy pass the value of `use_shardy_partitioner` to `get_compile_…
abb8999f
[Mosaic GPU] Wire up the plugin that contains the MGPU custom call in…
85b7e8fc
Make callbacks work in unrolled loops on TPU.
6156b6fa
[XLA:Python] [JAX] Change JAX to use the _profiler module defined in …
d300eb87
Raise a better error in `jax.linearize` when vma's don't match betwee…
af011cf4
Allow decorator factory pattern for `jax.shard_map` i.e. `@jax.shard_…
0e47af50
change build target
6bb73c22
roll back #28210 because it broke an internal test
56667fd6
[Mosaic] Support bf16 select and cmp <= TPUv4.
e80c4222
jax.numpy: move linspace & friends into array_creation
3a7728f2
Minor optim in profiler on session stop and export
4ef4c1ee
Replace auto tracking with manual axis names tracking internally in s…
c5e43670
Edit print_environment_info to print environment variables that start…
9b38623e
Fix wrong shard_map import
01ddb2f9
Add conv_general_dilated sharding rule
0b7befdf
[Pallas] Fix index_map equality checks
9434f45c
Remove `rep` completely from shard_map (rip rep from shmap -- RIP). T…
d1e0e58b
Add multiaccelerator H100 tests to optional GPU presubmit.
90e985c7
Reverts fbb0cbbc6bfe30941076a365a427e12efbb5253b
8a72e754
[jax] Switch lapack kernels to builtin FFI attrs decoding
07ca41f5
[XLA:Python] [JAX] Move XlaBuilder bindings out of JAX and into XLA:P…
b3e42331
[JAX] [XLA:Python] Move ShapeIndex bindings out of JAX and into XLA.
2fb49804
Update XLA dependency to use revision
94a843a3
Update XLA dependency to use revision
af568312
[Mosaic GPU] Add a `BroadcastInDim` op in the mosaic mlir dialect.
906cb417
[Mosaic GPU] Handle WGMMA_ROW_LAYOUT in `vector_store` lowering and i…
6d49c8f5
[Mosaic GPU] Delete `gpu.binary` after lowering.
8138dccf
[Mosaic GPU] Make sure the tests all pass on B200 (esp. in our CI)
4cbef87e
[Pallas:MGPU] Increase shard_coaunt for mgpu_attention_test to avoid …
404f7bc8
[Mosaic GPU] Relax tolerance for WGMMATest.test_narrow_n
c52ade7b
[Mosaic GPU] Fix a skip condition to avoid problems with nightly jaxl…
a835750d
Fix formatting error in matrix exclude strategy
d7d545b3
implement hyp2f1
d8cdedb2
Update `build.py` to avoid duplication of bazel options both in bazel…
ca971c34
[Mosaic] Remove Python pipeline.
82597a9e
Relax constraints in jnp.vectorize for output shapes with default sig…
4523d3ab
[JAX:sparse] Fix bcsr.from_bcoo to use the index_dtype of the input B…
34773fdc
Fix warning in mosaic/pipeline.py under Python 3.12.
6d6ae50c
Ensure __jax_array__ works properly with JIT disabled.
89ff6932
Allow 2x2x2 topologies with v6e
e3b45ad5
Fix scaled_matmul_stablehlo_test in x64 mode
4abcfb9e
Handle extended dtypes in FFI lowering.
21745c1a
[Mosaic TPU] Add support for narrow integer `arith.constant`s in kernels
5d837dd0
Add sharding devices to XlaCompileOptions and plumb them through from…
90ab27d6
Add `standard_insert_pvary` support to `reduce`.
57d7a117
[JAX] Add a python 3.14 requirement lock file and update WORKSPACE
1a7ce898
Bump actions/setup-python from 5.5.0 to 5.6.0
3cbc5a3e
[jax/docs] - Fix link in quickstart.md
57a6748e
Disable logsumexp test under SciPy 1.15.
42aa3ebc
chore: load py_library from rules_python
f0d02bf7
Change nightly installation instructions and CI to use new package in…
4ac29b92
Improving PyTorch data loading notebook
b2400141
Disable old JAX nightly builds
7eb3f5bb
Small changes for easier testing of kernels with TPU interpret mode.
58e75bf9
Update JAX nightly index usage
a1fe4ca5
Enter into correct mesh context in shard_map in `_partial_eval_jaxpr_…
727fe84c
Add GetAllocatorStats() method for device.
ba029782
[Mosaic GPU] Add layout inference and lowering for `vector.MultiDimRe…
e949ab4c
[Mosaic GPU] Add suport for packed TMEM
fdc4209f
Add note to changelog about nightly package switchover
1d76b6d0
Allow nightly workflows to halt for connection
80e89e86
Update XLA dependency to use revision
65f0e17b
Re-land "Inline literals while tracing instead of in a separate pass".
f0a479e1
Remove unused parameters in emit_python_callback.
641788a3
Remove deprecated GPU linalg kernels after compatibility period.
5f15f770
fix: use backend to call xb.process_count in _raise_warnings_or_error…
7472ea3f
Fix deprecation warnings in cudnn scaled matmul.
4b21c729
Support batch_axis being int rather than Sequence[int] in initializers.
af6d3d70
[JAX] Remove jaxlib/xla directory.
0f2a89dd
SPMD with jax.jit, namedsharding and partition spec
ad051116
Fix test failure in autodidax due to compile() API change.
af336765
PR #28391: Add version guard in autodidax test.
88d6a8fc
Update autodidax.ipynb and autodidax.md
aa7bdd67
Show Literal's aval when pretty printing.
cbe2b203
Relax test tolerance for PureCallbackTest.test_can_take_grad_of_pure_…
6d53a573
Set PYTHONWARNINGS=error for jax_multiplatform_test.
e7fd33e3
Round-robin pytest tests across GPUs.
f8b77865
[Pallas TPU] Add better support for BoundedSlice and other BlockDim t…
ff8341d6
charleshofer force pushed from e9c8fdcb to 7c833647 239 days ago
charleshofer merged 7c833647 into rocm-main 239 days ago
Reviewers: charleshofer
Assignees: No one assigned
Labels: None yet
Milestone: No milestone