jax
CI: 03/04/25 upstream sync
#255
Merged

charleshofer merged 145 commits into rocm-main from ci-upstream-sync-135_1
wenscarl Scaled matmul for mxfp8
061d4acb
wenscarl block_scale_config
332af587
wenscarl Improve comments.
4a395956
wenscarl Improve docstring.
2bb7f365
wenscarl Rename custom-call name.
ae111f7c
wenscarl Improve based on comment # 1
bfb9d3ca
wenscarl Conditionally create mxfp8_configs.
08012e9c
wenscarl refresh CI
d9456f36
Check "jax_rocm_visible_devices" at client creation.
6abc76c8
wenscarl Fix CI
681ee184
dfm Rename top level build file to BUILD.bazel.
525cb4bd
Rifur13 Allow query and keys that aren’t multiples of 128
a35494e0
dfm Use LAPACK trsm kernel even for batched solves.
553b441f
nitins17 Disable RBE on Windows
2f6f7221
Google-ML-Automation Merge pull request #26327 from ksebaz:fix-rocm-with-distributed
03e2c888
apaszke Remove MemoryEffects annotations from async_{load/store} ops
cb7402f6
Google-ML-Automation Use jax.Array as type annotation for pallas random keys
7c26ab53
nitins17 Use the 64 core Windows runner to build artifacts
cf01fdfe
matthiaskramm roofline: Support broadcasting, for binary ops.
08081c4d
matthiaskramm roofline: Add support for min_p, max_p, reduce_sum_p.
aad178a6
Google-ML-Automation Merge pull request #26676 from Rifur13:padding
467e0bdd
Google-ML-Automation Disable `//tests:serialization_test_cpu` from TSAN job and remove `te…
dc1c3f9a
nitins17 Use `${{ !cancelled() }}` instead of `${{ always() }}`
771306ba
nitins17 Install uv to fix module not found error on Windows
f57c18ad
matthiaskramm Add unfused_hbm usage to binary ops and dot_general.
e8543024
nitins17 Use `uv` instead of `pip` for installing Python packages
7566daba
jburnim [Pallas] Add option for async DMAs in the new TPU interpret mode
4c7140fa
wenscarl Improve after review # 2
17088e90
Google-ML-Automation Update XLA dependency to use revision
f21eefe1
hawkinsp Fix failures in TSAN free threading CI.
33bbd5f1
bchetioui [Pallas/Mosaic GPU][NFC] Move `thread_semantics` to `ModuleContext`.
7a34f1ce
apaszke [Pallas:MGPU] Don't recreate single_thread_predicate at every rule
3251b55e
apaszke [Mosaic GPU] Make sure to relayout FAs when their layouts mismatch in…
1de2f839
apaszke [Mosaic GPU] Add support for warpgroup lowering of loops with vector …
99a12ef9
hawkinsp Remove code present to support jaxlib < 0.5.1.
66293d88
Google-ML-Automation Upgrade to Bazel 7.4.1
c9c7250d
Google-ML-Automation Merge pull request #26762 from hawkinsp:tsan
eb55aef5
Google-ML-Automation Merge pull request #26712 from hawkinsp:ph3
d7849d5d
wsmoses Fix build dependencies
8262987a
dfm Add --system to uv commands in upstream-nightly workflow.
b8f236e6
Google-ML-Automation Merge pull request #26740 from dfm:fix-upstream-nightly-uv
8b99ddc0
jakevdp Implement jnp.ndarray.__contains__
7be7c489
Google-ML-Automation Merge pull request #26779 from jakevdp:array-contains
f3fade3b
wenscarl Add apache header.
7f0a5bc8
carlosgmartin Extend random.orthogonal to semi-orthogonal matrices. Simplify initia…
ba428d8c
nitins17 Enable resultstore logging
a65de524
Google-ML-Automation Merge pull request #26291 from carlosgmartin:simplify_nn_initializers…
8492897f
Google-ML-Automation Remove `tensorstore` dependency from `//jax/experimental/array_seria…
615219b1
emilyfertig Rolling back a commit that caused a 50-90% performance regression in …
7f9e7473
emilyfertig Redefine `is_fully_addressable` in shardings to support zero local de…
82124da5
sharadmv Add support for TPU v5 2x2 tray configuration
0f0d5e90
sharadmv [Pallas] Add name parameter to core_map
1ecbac97
sharadmv Add swap as method to TransformedRef
b5fcffad
sharadmv [Pallas TPU] Add support for GridDimensionSemantics to pallas_call
2646b8d4
tomhennigan Add `jax.copy_to_host_async(tree)`.
1becb57a
chr1sj0nes [pallas:triton] Generate more efficient code for loading contiguous s…
d6752e92
beckerhe Fix cudnn version skipping in fused_attention_stablehlo_test
b3f7c93c
Google-ML-Automation Update XLA dependency to use revision
0fbc453d
hawkinsp Fix a test failure under multi-threading.
6e736378
Google-ML-Automation Reverts f3fade3b70443b6cf87f01f360e6a1cb85d4b1fb
07f5d7a4
Google-ML-Automation Merge pull request #26804 from hawkinsp:tsan
a8738a06
Google-ML-Automation Update `jax_wheel` target to produce both wheel and source distributi…
4eb782e4
bartchr808 #sdy close any partially sharded dimensions if using `auto` axes in a…
4997e457
nitins17 Update the calculation for `num_processes` and `num_test_jobs` that a…
5ae0e58a
akuegel Change int4 packing from big-endian to little-endian
de4d0478
dfm Add and test support for partitioning of batch dimensions in lax.linalg.
f93c2a1a
Google-ML-Automation Add targets for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` edita…
401d3150
yashk2810 Canonicalize PartitionSpec so that we can delete ParsedPartitionSpec.…
177e1f6e
yashk2810 Remove `_parsed_pspec` from everywhere in JAX except for NamedShardin…
034a827a
yashk2810 More cleanups around ParsedPartitionSpec. In a follow up CL, I can re…
d69da3b0
Google-ML-Automation Comment change
da39b6f3
yashk2810 Use batched_device_put for token shard_arg handler
c94ec0eb
dfm Add version check to jaxlib plugin imports.
c7ed1bd3
Google-ML-Automation Disable certain tests on V4 and below.
3450e2ce
hawkinsp Add an allow_negative_indices option to lax.dynamic_slice and lax.dyn…
1e5d9a91
yashk2810 Remove parsed_pspec from NamedSharding constructor
c2655685
kanglant Add `-Wl,-undefined,dynamic_lookup` as linkopt on macOS
8bb31000
yashk2810 Merge _check_mesh_resource_axis and _check_axis_type_consistency into…
07f192cd
sharadmv [Pallas TPU] Use grid_env for pipeline body so we can query num_progr…
6f57410e
Google-ML-Automation Reverts 0f0d5e90ef1c3d60f35020141710ea350d17816b
6a773675
Google-ML-Automation Merge pull request #26345 from wenscarl:scaled_matmul
c7ca35fe
yashk2810 Make sure default layout is None for input and output layout in all c…
dda62f57
kanglant Reduce pytest workers for asan to resolve memory usage causing OOM
d839e441
dfm Temporarily skip linalg sharding tests on GPU.
810445b1
skye Update setup.py to automatically pick up libtpu patch releases
1b87ee07
Google-ML-Automation Remove spurious zero_to_zero conversion used exclusively for backcomp…
d8953e53
apaszke [Mosaic GPU][NFC] Start refactoring the MMA parameter inference
092ea353
apaszke [Mosaic GPU] Remove TMA inputs
832f5a3a
Google-ML-Automation Update XLA dependency to use revision
5a770701
bchetioui [Mosaic GPU][NFC] Move some functions to a new file called `inference…
abfe2d08
kanglant Add linux python 3.13t nightly tests
55263ce4
bchetioui [Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic s…
a9ab6141
bchetioui [Mosaic GPU] Fix `as_dialect_barrier_memref` to take into account `Ba…
7c46480e
bchetioui [Mosaic GPU][NFC] Delete workaround for dialect bindings before jaxli…
1bc36e62
apaszke [Mosaic GPU] Add support for small RHS tile sizes in WGMMA
bb96226d
dfm Reimplement custom_vjp.optimize_remat using custom_dce.
bb9aed5e
Google-ML-Automation Add build targets for `jax-rocm-plugin` and `jax-rocm-pjrt` wheels.
8f57b816
Google-ML-Automation Merge pull request #26835 from skye:libtpu_version
6ee261e9
Google-ML-Automation Merge pull request #26814 from dfm:remat-opt-custom-dce
10f6edeb
yashk2810 [sharding_in_types] `out_sharding` argument on einsum should only app…
da1cc0a5
nitins17 Print the list of installed packages before running pytests
0ed42dcd
kanglant Bump oldest supported libtpu to match the compatibility window (12 we…
da7c90c4
dfm Temporarily skip some more linalg sharding checks.
70024d22
jakevdp doc: fix description of logsumexp axis
c56e794a
Google-ML-Automation Add a profiler test for gpu run
48a55a6d
jakevdp Improved errors when indexing with floats
b2c45b8e
Reverts 10f6edeb496a2eec2a09c2c5cecbe4f8f02452ab
1f317663
Google-ML-Automation Chnages for kernel export
2a1eeb0c
Google-ML-Automation Update XLA dependency to use revision
e25caba9
yashk2810 `PRNGKeyArray.aval` should have the correct logical sharding. This re…
53494ade
Google-ML-Automation Update XLA dependency to use revision
9d6bcd63
dimitar-asenov [Mosaic GPU] Wire up the `slice_lengths` and `indices` operands in lo…
c60ef5a2
dimitar-asenov [Mosaic GPU] Infer layouts (transforms) on memrefs that directly feed…
3b305c66
pschuh Add use_high_dynamic_range_gumbel flag which allows sampling gumbel such
b8b690e5
Google-ML-Automation Update XLA dependency to use revision
eee4d601
gnecula Use the same name for aliased Vars when pretty-printing Jaxprs.
a6c47d6f
apaszke [Mosaic GPU][NFC] Clean up the computation of group strides
3038348f
Google-ML-Automation Merge pull request #26697 from gnecula:pp_aliased_var_names
bbadf990
apaszke [Mosaic GPU][NFC] Move the calculation of group strides into _validat…
11e6cfbc
cperivol [mgpu] Forach in tiled layout.
b9ebd918
bartchr808 #sdy support JAX export tests when Shardy is enabled.
ac493655
bartchr808 #sdy Add JAX backwards compatibility test.
ed4a7bba
apaszke [Mosaic GPU] Make the small WGMMA tile independent of transpose flags
e9f95cc3
Google-ML-Automation Fix convolution example (kernel should be OIHW, not IOHW).
1a57fdf7
WindQAQ [Mosaic] Rename dep name.
5179642e
carlosgmartin Fix linalg.norm to return zero for proper norms of empty matrices.
897e1a13
hawkinsp Fix wrong results in multidimensional pad.
7f05b74b
jakevdp doc: in lax.cond, note that both branches will be traced
84ca80d2
Google-ML-Automation Merge pull request #26895 from hawkinsp:pad2
439c412c
yashk2810 Remove the skip for test_output_streaming_inside_scan
07c4c03a
Google-ML-Automation Merge pull request #26876 from carlosgmartin:fix_matrix_norm_empty_ma…
07d1cd02
Google-ML-Automation Merge pull request #26897 from jakevdp:cond-doc
4944dcb9
Google-ML-Automation Merge pull request #26862 from jakevdp:logsumexp-docs
f9f47217
Google-ML-Automation Merge pull request #26865 from jakevdp:fix-indexing-error
2c7043f6
sharadmv [Pallas] Add experimental (private for now) API for manual fusion int…
0b6c3550
sharadmv Add fuser to jax.experimental.pallas
d32e282f
sharadmv [Pallas/Fuser] Add support for custom_call_jvp/pjit to push_block_spec
00d9f452
ayaka14732 Fix thread safety of JAX error checking
ea53c761
Google-ML-Automation Update XLA dependency to use revision
96dce0b6
dfm Update jnp.unique to support upstream interface changes.
6c5ef1a4
Google-ML-Automation Merge pull request #26883 from dfm:np-unique-sorted
8906f281
superbobry [pallas:triton] Emit a better error message for matmul with non-2D op…
155839bb
apaszke [Mosaic GPU] Make sure to do the async proxy fence before wargroup sync
cdae5fcf
Google-ML-Automation Remove redundant `BUILD_TAG` from `JAX` wheels build rule.
ce3412e5
charleshofer Merge branch 'rocm-main' into ci-upstream-sync-135_1
07cd809b
charleshofer requested a review 314 days ago
charleshofer merged d6cf755f into rocm-main 314 days ago
charleshofer deleted the ci-upstream-sync-135_1 branch 314 days ago

Reviewers
No reviews
Assignees
No one assigned