jax
CI: 03/04/25 upstream sync
#255
Merged
charleshofer
merged 145 commits into
rocm-main
from
ci-upstream-sync-135_1
Scaled matmul for mxfp8
061d4acb
block_scale_config
332af587
Improve comments.
4a395956
Improve docstring.
2bb7f365
Rename custom-call name.
ae111f7c
Improve based on comment # 1
bfb9d3ca
Conditionally create mxfp8_configs.
08012e9c
refresh CI
d9456f36
Check "jax_rocm_visible_devices" at client creation.
6abc76c8
Fix CI
681ee184
Rename top level build file to BUILD.bazel.
525cb4bd
Allow query and keys that aren’t multiples of 128
a35494e0
Use LAPACK trsm kernel even for batched solves.
553b441f
Disable RBE on Windows
2f6f7221
Merge pull request #26327 from ksebaz:fix-rocm-with-distributed
03e2c888
Remove MemoryEffects annotations from async_{load/store} ops
cb7402f6
Use jax.Array as type annotation for pallas random keys
7c26ab53
Use the 64 core Windows runner to build artifacts
cf01fdfe
roofline: Support broadcasting, for binary ops.
08081c4d
roofline: Add support for min_p, max_p, reduce_sum_p.
aad178a6
Merge pull request #26676 from Rifur13:padding
467e0bdd
Disable `//tests:serialization_test_cpu` from TSAN job and remove `te…
dc1c3f9a
Use `${{ !cancelled() }}` instead of `${{ always() }}`
771306ba
Install uv to fix module not found error on Windows
f57c18ad
Add unfused_hbm usage to binary ops and dot_general.
e8543024
Use `uv` instead of `pip` for installing Python packages
7566daba
[Pallas] Add option for async DMAs in the new TPU interpret mode
4c7140fa
Improve after review # 2
17088e90
Update XLA dependency to use revision
f21eefe1
Fix failures in TSAN free threading CI.
33bbd5f1
[Pallas/Mosaic GPU][NFC] Move `thread_semantics` to `ModuleContext`.
7a34f1ce
[Pallas:MGPU] Don't recreate single_thread_predicate at every rule
3251b55e
[Mosaic GPU] Make sure to relayout FAs when their layouts mismatch in…
1de2f839
[Mosaic GPU] Add support for warpgroup lowering of loops with vector …
99a12ef9
Remove code present to support jaxlib < 0.5.1.
66293d88
Upgrade to Bazel 7.4.1
c9c7250d
Merge pull request #26762 from hawkinsp:tsan
eb55aef5
Merge pull request #26712 from hawkinsp:ph3
d7849d5d
Fix build dependencies
8262987a
Add --system to uv commands in upstream-nightly workflow.
b8f236e6
Merge pull request #26740 from dfm:fix-upstream-nightly-uv
8b99ddc0
Implement jnp.ndarray.__contains__
7be7c489
Merge pull request #26779 from jakevdp:array-contains
f3fade3b
Add apache header.
7f0a5bc8
Extend random.orthogonal to semi-orthogonal matrices. Simplify initia…
ba428d8c
Enable resultstore logging
a65de524
Merge pull request #26291 from carlosgmartin:simplify_nn_initializers…
8492897f
Remove `tensorstore` dependency from `//jax/experimental/array_seria…
615219b1
Rolling back a commit that caused a 50-90% performance regression in …
7f9e7473
Redefine `is_fully_addressable` in shardings to support zero local de…
82124da5
Add support for TPU v5 2x2 tray configuration
0f0d5e90
[Pallas] Add name parameter to core_map
1ecbac97
Add swap as method to TransformedRef
b5fcffad
[Pallas TPU] Add support for GridDimensionSemantics to pallas_call
2646b8d4
Add `jax.copy_to_host_async(tree)`.
1becb57a
[pallas:triton] Generate more efficient code for loading contiguous s…
d6752e92
Fix cudnn version skipping in fused_attention_stablehlo_test
b3f7c93c
Update XLA dependency to use revision
0fbc453d
Fix a test failure under multi-threading.
6e736378
Reverts f3fade3b70443b6cf87f01f360e6a1cb85d4b1fb
07f5d7a4
Merge pull request #26804 from hawkinsp:tsan
a8738a06
Update `jax_wheel` target to produce both wheel and source distributi…
4eb782e4
#sdy close any partially sharded dimensions if using `auto` axes in a…
4997e457
Update the calculation for `num_processes` and `num_test_jobs` that a…
5ae0e58a
Change int4 packing from big-endian to little-endian
de4d0478
Add and test support for partitioning of batch dimensions in lax.linalg.
f93c2a1a
Add targets for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` edita…
401d3150
Canonicalize PartitionSpec so that we can delete ParsedPartitionSpec.…
177e1f6e
Remove `_parsed_pspec` from everywhere in JAX except for NamedShardin…
034a827a
More cleanups around ParsedPartitionSpec. In a follow up CL, I can re…
d69da3b0
Comment change
da39b6f3
Use batched_device_put for token shard_arg handler
c94ec0eb
Add version check to jaxlib plugin imports.
c7ed1bd3
Disable certain tests on V4 and below.
3450e2ce
Add an allow_negative_indices option to lax.dynamic_slice and lax.dyn…
1e5d9a91
Remove parsed_pspec from NamedSharding constructor
c2655685
Add `-Wl,-undefined,dynamic_lookup` as linkopt on macOS
8bb31000
Merge _check_mesh_resource_axis and _check_axis_type_consistency into…
07f192cd
[Pallas TPU] Use grid_env for pipeline body so we can query num_progr…
6f57410e
Reverts 0f0d5e90ef1c3d60f35020141710ea350d17816b
6a773675
Merge pull request #26345 from wenscarl:scaled_matmul
c7ca35fe
Make sure default layout is None for input and output layout in all c…
dda62f57
Reduce pytest workers for asan to resolve memory usage causing OOM
d839e441
Temporarily skip linalg sharding tests on GPU.
810445b1
Update setup.py to automatically pick up libtpu patch releases
1b87ee07
Remove spurious zero_to_zero conversion used exclusively for backcomp…
d8953e53
[Mosaic GPU][NFC] Start refactoring the MMA parameter inference
092ea353
[Mosaic GPU] Remove TMA inputs
832f5a3a
Update XLA dependency to use revision
5a770701
[Mosaic GPU][NFC] Move some functions to a new file called `inference…
abfe2d08
Add linux python 3.13t nightly tests
55263ce4
[Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic s…
a9ab6141
[Mosaic GPU] Fix `as_dialect_barrier_memref` to take into account `Ba…
7c46480e
[Mosaic GPU][NFC] Delete workaround for dialect bindings before jaxli…
1bc36e62
[Mosaic GPU] Add support for small RHS tile sizes in WGMMA
bb96226d
Reimplement custom_vjp.optimize_remat using custom_dce.
bb9aed5e
Add build targets for `jax-rocm-plugin` and `jax-rocm-pjrt` wheels.
8f57b816
Merge pull request #26835 from skye:libtpu_version
6ee261e9
Merge pull request #26814 from dfm:remat-opt-custom-dce
10f6edeb
[sharding_in_types] `out_sharding` argument on einsum should only app…
da1cc0a5
Print the list of installed packages before running pytests
0ed42dcd
Bump oldest supported libtpu to match the compatibility window (12 we…
da7c90c4
Temporarily skip some more linalg sharding checks.
70024d22
doc: fix description of logsumexp axis
c56e794a
Add a profiler test for gpu run
48a55a6d
Improved errors when indexing with floats
b2c45b8e
Reverts 10f6edeb496a2eec2a09c2c5cecbe4f8f02452ab
1f317663
Changes for kernel export
2a1eeb0c
Update XLA dependency to use revision
e25caba9
`PRNGKeyArray.aval` should have the correct logical sharding. This re…
53494ade
Update XLA dependency to use revision
9d6bcd63
[Mosaic GPU] Wire up the `slice_lengths` and `indices` operands in lo…
c60ef5a2
[Mosaic GPU] Infer layouts (transforms) on memrefs that directly feed…
3b305c66
Add use_high_dynamic_range_gumbel flag which allows sampling gumbel such
b8b690e5
Update XLA dependency to use revision
eee4d601
Use the same name for aliased Vars when pretty-printing Jaxprs.
a6c47d6f
[Mosaic GPU][NFC] Clean up the computation of group strides
3038348f
Merge pull request #26697 from gnecula:pp_aliased_var_names
bbadf990
[Mosaic GPU][NFC] Move the calculation of group strides into _validat…
11e6cfbc
[mgpu] Foreach in tiled layout.
b9ebd918
#sdy support JAX export tests when Shardy is enabled.
ac493655
#sdy Add JAX backwards compatibility test.
ed4a7bba
[Mosaic GPU] Make the small WGMMA tile independent of transpose flags
e9f95cc3
Fix convolution example (kernel should be OIHW, not IOHW).
1a57fdf7
[Mosaic] Rename dep name.
5179642e
Fix linalg.norm to return zero for proper norms of empty matrices.
897e1a13
Fix wrong results in multidimensional pad.
7f05b74b
doc: in lax.cond, note that both branches will be traced
84ca80d2
Merge pull request #26895 from hawkinsp:pad2
439c412c
Remove the skip for test_output_streaming_inside_scan
07c4c03a
Merge pull request #26876 from carlosgmartin:fix_matrix_norm_empty_ma…
07d1cd02
Merge pull request #26897 from jakevdp:cond-doc
4944dcb9
Merge pull request #26862 from jakevdp:logsumexp-docs
f9f47217
Merge pull request #26865 from jakevdp:fix-indexing-error
2c7043f6
[Pallas] Add experimental (private for now) API for manual fusion int…
0b6c3550
Add fuser to jax.experimental.pallas
d32e282f
[Pallas/Fuser] Add support for custom_call_jvp/pjit to push_block_spec
00d9f452
Fix thread safety of JAX error checking
ea53c761
Update XLA dependency to use revision
96dce0b6
Update jnp.unique to support upstream interface changes.
6c5ef1a4
Merge pull request #26883 from dfm:np-unique-sorted
8906f281
[pallas:triton] Emit a better error message for matmul with non-2D op…
155839bb
[Mosaic GPU] Make sure to do the async proxy fence before warpgroup sync
cdae5fcf
Remove redundant `BUILD_TAG` from `JAX` wheels build rule.
ce3412e5
Merge branch 'rocm-main' into ci-upstream-sync-135_1
07cd809b
charleshofer
requested a review
314 days ago
charleshofer
merged
d6cf755f
into rocm-main
314 days ago
charleshofer
deleted the ci-upstream-sync-135_1 branch
314 days ago