jax
Rocm jaxlib v0.9.0
#671
Open

Ruturaj4 wants to merge 6583 commits into main from rocm-jaxlib-v0.9.0
boomanaiden154 Prefer isinstance(x, type) over type.isinstance
5006973b
Google-ML-Automation Merge pull request #34204 from jakevdp:doc-sidebar
54cf88d9
Google-ML-Automation Merge pull request #32268 from samanklesaria:issues/32267
227cb221
boomanaiden154 Prefer isinstance(x, type) over type.isinstance
27b3bffa
jakevdp Add jax.experimental.random to the wheel build
2486a5e3
mattjj fix shard_map transpose explicit sharding zero unsharding bug
bf11cd8a
mattjj add optional `explain` callback for weakref_lru_cache misses
f6b6edc1
Google-ML-Automation Merge pull request #34209 from jakevdp:fix-wheel
75fd86b0
Google-ML-Automation Merge pull request #34211 from mattjj:andy-customvjp-none
75130d4d
yashk2810 Handle ad.Zero cotangents in _reshard_transpose_fancy.
caaad6e0
Google-ML-Automation Update XLA dependency to use revision http://github.com/openxla/xla/c…
fd57ef61
bchetioui [Pallas/Mosaic GPU] Enable more `WarpSpecializedPipelineWGTest`s.
1199fd7f
bchetioui [Mosaic GPU] Add a `bitwidth` field to `Relayout` constraints in layo…
de96f7b9
superbobry [pallas:sc] Allowed specifying tiling in `pltpu.emit_pipeline`
2a1a0a96
golechwierowicz [XLA:MGPU] Port Tiling to C++.
60a32d38
keshavb96 deviceless aot test
8c2555a8
emilyfertig Add a thread guard config option.
07bb0f6c
WindQAQ [Mosaic] Move Float8EXMYType to tpu.td.
1c3aa2ff
Google-ML-Automation Merge pull request #33542 from keshavb96:deviceless_aot_test
4b258967
ybaturina Remove redundant test targets that are already executed as a part of …
4e89ce42
mattjj respect self.statics in FlatTree.__eq__
9fad589b
boomanaiden154 Fix precommit breakage
b02123b9
mattjj sick
52790002
mattjj revive as many cache miss explanations as reasonably possible
06365873
superbobry [pallas:sc] Skip a few tests failing when the compiler uses tiled mem…
12dce9c0
Google-ML-Automation Colocated python: Use a wrapper when storing remote objects at the ba…
b10cee69
mattjj skip on jaxlib version
d6101d87
ybaturina Add `halt-for-connection` to `build_artifacts.yml` workflow call.
909e638a
Google-ML-Automation Merge pull request #33839 from jax-ml:pjit-without-linear-util
a8ca82a7
Google-ML-Automation Update XLA dependency to use revision http://github.com/openxla/xla/c…
774597a4
bchetioui [Mosaic GPU] Use `isinstance(x, mlir_ty)` instead of the deprecated `…
5a50361d
beckerhe Remove unnecessary return from placement new in Mosaic GPU extension.
568bca12
superbobry [mosaic] Added a canonicalization rule for memref.dim of tpu.memref_s…
35ffd406
apaszke [Mosaic GPU] Add basic support for DSMEM
8947c9a2
apaszke [Mosaic GPU] Add support for using the redux instructions to speed up…
73ad4e5c
bchetioui [Pallas/Mosaic GPU] Disable `test_ragged_dot_transposed` temporarily.
e27b89b7
hawkinsp Remove some more stale version guards.
741ecce8
jakevdp [bug] fix grad-of-vmap-of-dynamic_slice with out-of-bound indices
0b4a4407
pschuh Allow max_size to be None (infinite).
c1f8ccfc
yashk2810 Don't pass axis_index_groups to prim.bind in psum batching rule if pr…
09c3f1c9
Google-ML-Automation Migrates `builder.create<Op>()` => `Op::create()`
d1918610
Google-ML-Automation Merge pull request #34229 from jakevdp:dynamic-slice-grad-fix
046b70ca
emilyfertig Fix Array `__format__` and `__str__` to handle McJAX arrays that are …
9528be82
mattjj [hijax] start implementing custom_vjp on top of hijax primitives
d557a060
Google-ML-Automation Merge pull request #34161 from mattjj:custom-vjp3-youandme
705171ab
hawkinsp Delete the jax_collectives_common_channel_id flag.
077574f8
jakevdp [indexing] implement strategy='dynamic_slice'
1447d898
Google-ML-Automation Merge pull request #34225 from jakevdp:index-dynamic-slice
eab73dbb
hawkinsp [TPU] Reenable a disabled test that now passes.
6acf46cc
Rifur13 [Pallas MGPU] Use the `cp.async.bulk` instruction for large contiguou…
2f3063c0
Rifur13 [Mosaic GPU] Expose the `mbarrier.complete_tx` instruction to manuall…
d800c3d6
rdyro Shorten splash attention kernel name to address https://github.com/ja…
22b3a0c2
mattjj [shmap] fix shmap error logic when subclasses of pspec are used
b2217c3b
pschuh Run the GetKeys inside GetOrCreate.
6744656b
Google-ML-Automation Merge pull request #34262 from mattjj:shmap-error-isleaf
603cef63
pschuh Run g4 fix on the weakref_lru_cache code.
e184e210
ybaturina Update rules_ml_toolchain version to accommodate custom redistributio…
c5e57367
rdyro Fix deserialization with specified layouts via a ShapeDtypeStruct bei…
a930721b
18praveenb Remove obsolete filegroup
11abf3bd
Google-ML-Automation Update XLA dependency to use revision http://github.com/openxla/xla/c…
b00dc36f
mattjj don't warn on complex->real cast in dot transpose
c0fcb0b2
Google-ML-Automation Merge pull request #33708 from mattjj:issue33521
05d4abb0
mattjj [hijax] add vmap support to CustomVJP hijax primitive
902cee57
Google-ML-Automation Merge pull request #34268 from mattjj:custom-vjp3-youandme-2
6f73035e
Google-ML-Automation Update XLA dependency to use revision http://github.com/openxla/xla/c…
d40b4c95
oulgen [Mosaic] Add abs, sign, erf, atan2, reduce_min, reduce_prod support t…
447b0041
hyeontaek [JAX] Add IFRT SerDes to jaxlib deps
09748056
Google-ML-Automation Update XLA dependency to use revision http://github.com/openxla/xla/c…
58154561
Google-ML-Automation Merge pull request #34120 from oulgen:mgpu-ops
1ef117f3
apaszke [Mosaic GPU] Test approximate math functions properly + eta reduce args
0ad11175
apaszke [Mosaic GPU] Enable redux.sync.f32 on Blackwell
098e953a
golechwierowicz [XLA:MGPU] Port Replicated wrapper to C++.
658e6505
superbobry [mosaic] infer-memref-layout now accepts target shape as a span
dcb011a1
allanrenucci [Pallas:MGPU] Lower `lax.sign` consistently for LANE and WG semantics.
e4fa0259
allanrenucci [Pallas:MGPU] Add Pallas lowering for `pl.debug_check` under WG seman…
a070482d
emilyfertig Add `thread_guard` to the public API.
1242ba32
allanrenucci [Mosaic GPU] Add support for f8 types for WGMMA with lhs in registers…
7cee4db6
Rifur13 [Pallas MGPU] Clip the size of contiguous TMA transfers to the size o…
27a568a3
Prakharprasun lax: ensure padtype_to_pads returns Python ints
6ce4314f
Matt-Hurd [jax.collect_profile] Allow arbitrary options to be passed to XProf
a4036e4a
jakevdp [indexing] support newaxis in static/dynamic slice strategies
aad0e01c
sharadmv Plumb prim params for call discharge rule (to handle named_computation_p
96787dd8
superbobry [sc] Generalized infer-memref-layout to support SC tiling
06d75423
sharadmv [Pallas/TPU] Don't lower eqns that have all dropvar outputs (aka DCE at
9e8fae66
superbobry [sc] Removed `infer_kernel_arguments` from infer-memref-layout
c984def2
Google-ML-Automation Merge pull request #33974 from Prakharprasun:fix-padtype-to-pads
a10149b6
sharadmv [Pallas] Allowlist semaphore/prng effects under remat and custom
c5e42402
levskaya Add lax.tile_p
c396a367
Google-ML-Automation Automated Code Change
bd07769f
tlongeri [Mosaic:TPU] Clean up tpu.memref_slice verifier
4691dfee
Google-ML-Automation Update XLA dependency to use revision http://github.com/openxla/xla/c…
9a243d76
levskaya Add pallas lowering for jnp.tile using tpu.repeat.
f0ca449a
gnecula [export] Add support for explicit sharding.
96ecec4f
gnecula Handle the case when shardings are GSPMDSharding.
b7700e47
allanrenucci [Mosaic GPU] Support more reduction kinds and layouts in `MultiDimRed…
e5faf1d2
Google-ML-Automation [Pallas/interpreter] Add a prototype for a GPU kernel interpreter.
6eb71e90
allanrenucci [Pallas:MGPU][NFC] Update outdated docstring for `scratch_view`.
f25011f0
bchetioui [Mosaic GPU] Fix issue in `vector_dim` reduction with `vec_len > 2`.
678adf59
allanrenucci [Mosaic GPU][NFC] Replace usages of `_gpu_ops_gen` with `gpu` dialect…
35a14b8b
bchetioui [Mosaic GPU] Fix the return type of `_lift_fast_packed_instr` to matc…
ed12030c
gnecula Set the proper memory kinds for output in call_exported
4b76bd88
Google-ML-Automation Merge pull request #33597 from gnecula:export_shit
fe56c711
Google-ML-Automation Merge pull request #34290 from jakevdp:indexing-none
4fdb82cd
allanrenucci [Pallas:MGPU] Remove broadcast in scalar reduce test.
3e7d4dd5
ybaturina Change configurations for nightly TPU tests:
1a74daa6
mattjj [hijax] more custom_vjp support: remat and optimize_remat
4732c732
quoctruong Update the correct core count for TPU v7 runners
5267d8ad
ybaturina Fix numpy dependency.
1bbbe8c1
Google-ML-Automation Merge pull request #34320 from mattjj:custom-vjp3-youandme-3
940d2410
ybaturina Add "jax/_src/pallas/mosaic_gpu/interpret:interpret_pallas_call" to j…
99f65585
Rifur13 Revert switch to using `cp.async.bulk`.
fcbffe9c
JW1992 Breaking existing tests
0b7fc6a5
sharadmv [Pallas] Add delay effect
2a5f6379
emilyfertig Add "jax_force_dcn_cross_host_transfers" flag to use DCN instead of t…
4d76b711
sharadmv [Pallas TPU] Fix bug where in kernels that only contain inter-core
5f9c795f
yashk2810 Use pe._default_dce_rule if a prim doesn't have a `dce` method on it …
4f642fc0
Google-ML-Automation Update XLA dependency to use revision http://github.com/openxla/xla/c…
b66cb459
allanrenucci [Mosaic GPU] Support signed/unsigned min reductions in `MultiDimReduc…
ab6c95cd
allanrenucci [Pallas:MGPU] Implement min/max scalar reductions under WG semantics.
7980031e
superbobry [pallas:mosaic] tpu.memref_{slice,squeeze} no longer support strided …
974972b2
danielsuo Update XLA commit to https://github.com/openxla/xla/commit/c358f68c2a…
e9d2deee
danielsuo [pmap] Suppress PmapSharding deprecation warning in multihost_utils_t…
12fb9d22
yashk2810 Don't use convert_element_type to cast in dot_general bwd pass. Inste…
b8c00a84
georgepaw [JAX] Populate Send/Recv frontend attributes.
4b9bbb5a
danielsuo [pmap] Remove the `jax_pmap_no_rank_reduction` config state.
de8e4e89
levskaya replace uses of pltpu.repeat with jnp.tile in pallas.
34a4cbce
danielsuo [pmap] Expand docs for migrating from pmap to shard_map.
9efab3cf
jakevdp Remove problematic warnings filter
697b18d9
jakevdp Add mapped_aval and unmapped_aval to jax.extend.core
ae582232
jakevdp [dep] Deprecate jax.numpy.fix
09684bc4
Google-ML-Automation Merge pull request #34310 from jakevdp:mapped-aval
69e2fa59
Google-ML-Automation Merge pull request #34293 from jakevdp:dep-fix
52ea5e89
allanrenucci [Pallas:MGPU] Enable `test_load_store_wgmma_transposed` under WG sema…
82183245
gnecula Reverts 0b7fc6a58ec36929ff5e98b46a359551839604eb
540671e0
jakevdp [indexing] fully support indexing & normalization modes within static…
e4e22fc5
allanrenucci [Pallas:MGPU] Simplify Mosaic GPU tests by removing intermediate SMEM…
20627f67
Google-ML-Automation Merge pull request #34327 from jakevdp:warning-filter
2e29fd98
Google-ML-Automation Merge pull request #34326 from jakevdp:indexing-normalize-indices
ddddd11b
Cjkkkk relax fp8 sdpa test tolerance
30e528ad
Google-ML-Automation [Pallas/interpreter] Add support for `run_scoped` in the GPU kernel i…
68e20333
sharadmv [Pallas] Fix broken pallas distributed test
a4befcdd
ZixuanJiang Remove support for some old format in `jax/tests/export_back_compat_t…
6bb594ab
Google-ML-Automation Reverts b8c00a84ddc1c463fcfe7b7bbdaf2eaa670886df
3422bfcc
pschuh Only reset preferred_element_type = x.aval.dtype if explicitly set
7cd6b1a9
emilyfertig Fix backwards compatibility with jaxlib.
267a9a58
tlongeri [NFC] Use getDefiningOp<Op>
a13f53eb
Google-ML-Automation Update XLA dependency to use revision http://github.com/openxla/xla/c…
643d2e4e
gnecula [export] Add backwards compatibility tests for v6.
7a7fd1eb
Google-ML-Automation Merge pull request #33853 from gnecula:export_v6_tests
bcc959b9
golechwierowicz [XLA:MGPU] Port TiledLayout's construction logic to C++.
9ac60a34
Google-ML-Automation [Mosaic GPU] Disable redux ops until properly benchmarked and optimized
ef84d389
golechwierowicz [XLA:MGPU] Port Partitioned.*Dims and VectorLength methods to C++.
92f08100
Google-ML-Automation Reverts 7cd6b1a9c711068f4f05312d24ef598fe40d4e50
15a3e1ed
gnecula [pallas] Make the LoweringDynamicShapeEnv use local mappings
602f212d
gnecula [export] Cleanup the handling of has_named_sharding
2df10cb3
Google-ML-Automation Merge pull request #34403 from gnecula:export_fix1
31ef45ef
Google-ML-Automation Merge pull request #34400 from gnecula:shape_poly_pallas_cache_reuse
d66c2a69
superbobry [pallas:sc] The lowering can now use an empty grid directly
0c804685
danielsuo [pmap] Deprecate setting `jax_pmap_shmap_merge` config state.
bb1017b8
ybaturina Update `rules_ml_toolchain` version to incorporate new GPU folders st…
a8927894
hawkinsp Add a version guard around TPU lowering rule changes.
e16ca653
ybaturina Add the test rule `compare_srcs_and_test_deps_test` that compares the…
8d97179e
allanrenucci [Pallas:MGPU] Fix scratch size for cross-warp reductions.
695eac8e
Google-ML-Automation Merge pull request #34375 from Cjkkkk:relax_fp8_sdpa_tolerance
c5831d16
mattjj [hijax] fix up custom_vjp3 error messages
002c3dc3
sharadmv [Pallas] Turn off tpu_7x configs for distributed test to avoid hangs
7f6fe293
sharadmv [Pallas] Remove PRNG/Semaphore effects from basic JAX allowlist since
ee85b82f
jakevdp [indexing] use strides=None for all unit strides
90076305
jakevdp [typ] annotate lax.padtype_to_pads
e64dccf8
Google-ML-Automation Merge pull request #34394 from mattjj:custom-vjp3-youandme-4
429fcfc0
mwhittaker Uninstall xprof on python3.13-nogil always.
fe7676d4
rdyro [Pallas] Disable non-sharded multi-chip splash attention test config,…
9e481477
jakevdp [hijax] fix shard_map of hijax primitive
2ecafeee
danielsuo [pmap] Suppress PmapSharding deprecation warning in (more) multihost_…
3d400948
Google-ML-Automation Merge pull request #34414 from jakevdp:padtype-to-pads
1143012a
Google-ML-Automation Merge pull request #34410 from jakevdp:indexing-strides
f0ff1b85
belitskiy Migrate TSAN workflow to use RBE.
acdd012a
loislo Follow up pr after triton integration cl and tokamax pr.
91223b93
Google-ML-Automation Merge pull request #34418 from jakevdp:hijax-shard-map
6a9f3d98
Google-ML-Automation [Pallas] Create trace_value primitive for dynamic value logging
02939fef
ybaturina Add `testonly=True` to py_import targets that depend on testonly `whe…
a84ca0c5
mwhittaker Bump XLA version.
cdfe904a
mwhittaker Prepare for JAX release 0.9.0
dd745d79
belitskiy Increase shard count for TPU and GPU tests back to 5 for api_test.py.
8f4389a5
ybaturina Fix wheel sources tests for Windows platform.
fae5defd
emilyfertig Disable `tests/multiprocess:socket_transfer_test` since it's failing …
d97eed6e
emilyfertig Re-enable `socket_transfer_test` internally.
18b34a12
mwhittaker Add libtpu date guard for failing test.
9e7b005e
mwhittaker Skip tpu_pallas_distributed_test on 7x.
77d9ffb0
mwhittaker Add libtpu guard to failing tpu_trace_value_test.
80b1ef6a
mwhittaker Disable `tpu_splash_attention_kernel_test` on TPU v7x.
28799d51
18praveenb Remove failing test_itof_dot_canonicalization_fails_without_compat_mo…
98d6b4ba
yashk2810 Use maxsize=None with trace_to_jaxpr's weakref_lru_cache to get more …
2de5b8b0
charleshofer Remove nvidia_wheel_versions
c706c2fa
charleshofer Make jaxlib targets visible
de91c59a
charleshofer hipblas typedef fix
6e0eb3ec
charleshofer No GPU fail
4f3741b3
AratiGanesh Wrap HIP inline functions in anonymous namespaces in vendor.h
2b08a20d
dsicarov-amd SWDEV-512768 - Replace hipGetLastError with hipExtGetLastError
224b5baf
charleshofer Add shared utility function get_rocm_version to test_util.py
1bac5e70
phambinhfin Fix hipSparse CSR algorithm mappings for ROCm 7
81cdeb49
phambinhfin Make nvidia version data optional for ROCm builds
d88ca192
phambinhfin Fix v_pages quantization and adjust test params for ROCm compatibilit…
61d28322
Arech8 Address LLVM assertion failure due to a multithreaded use. Update .gi…
8b0b3d10
Arech8 Add skip of test_is_finite() on Cuda (#565)
6ab35026
AratiGanesh Add rocm test requirements file (#570)
835856e0
charleshofer Let the unit tests use build.py for setting up Bazel commands for uni…
6bd1a15e
gulsumgudukbay adding abort logic to rocm/jax (#590)
a83cc9e9
phambinhfin Skip is_finite tests on ROCm (not in Triton lowering for jax 0.8.0) (…
17c7560c
phambinhfin Fix shared memory limit check for ROCm in test_dot (#596)
e5b66138
magaonka-amd Fix Numpy signatures test (#598)
1a823af2
tsrw2048 Fix GPU lowering rule for SVD on ROCm devices (#600)
fd0e42c9
tsrw2048 Fixed merge conflicts when moving GESVDJ commit from JAX 0.8.0 to 0.8…
aaac58cb
Ruturaj4 fix merge arts
e8ac89ba
tsrw2048 Enabled testTridiagonal to run on ROCm devices (#607)
ee2c254e
tsrw2048 Enabled ToeplitzSymmetricConstruction unit tests for ROCm devices (#608)
455f6b10
phambinhfin Enable testSvdSubsetByIndex for ROCm with subset_by_index skip (#603)
fc3bc807
magaonka-amd Fix KeyError for bytes_reservable_limit on ROCm
0956a239
magaonka-amd Backport cuda_array_interface , testDotAlgorithm, Optimizer test fixe…
ca932346
gulsumgudukbay Enable RngShardingTests (#644)
4a075b9a
magaonka-amd Enable reduce_window tests on ROCm (#643)
2557ff08
magaonka-amd Enable test_variadic_reduce_window on ROCm (#647)
134825e3
gulsumgudukbay Unskip supported dtypes for testConvolutionsPreferredElementType (#649)
ee5581c8
tsrw2048 Enabled test for condition number on ROCm devices. (#613)
10c7e75a
tsrw2048 Enabled RNN unit test: test_no_workspace_overflow for ROCm devices (#…
708d732f
tsrw2048 Added changes from PR #626 and PR #645. This also fixes merge conflic…
98cc66ca
Ruturaj4 requested a review 26 days ago
Ruturaj4 [Pallas] Fix ROCm GPU architecture detection and route to Triton backend
b77014fd
magaonka-amd Enable array interoperability tests on ROCm platform (#660)
91be368b
magaonka-amd Skip sparse tests on ROCm due to hipSPARSE issue (#652)
95e1df00
magaonka-amd Update sparse test skip messages in v0.8.2 (#653)
e0739e6c
tsrw2048 Port skip for "test_prim_tridiagonal_solve" tests from JAX v0.8.0 to …
a988aebf
magaonka-amd Fix test_cuda_array_interface test skip condition (#657)
34676201
AratiGanesh Enable testMultivariateNormalSingularCovariance on ROCm (#666)
ed14fc6b
AratiGanesh Skip test_batch_axis_sharding_jvp because of hipSPARSE issue (#667)
fc10735b
AratiGanesh Skip test_tridiagonal_solve on ROCm due to hipSPARSE numerical errors…
93f5fdc1
gulsumgudukbay Update Skip Reason Outputs (#663)
3747d88a
magaonka-amd Skip testCudaArrayInterfaceOnNonCudaFails on ROCm platform (#677)
103da597
magaonka-amd Enable lobpcg tests on ROCm platform (#681)
d3cce970
magaonka-amd Enable lax backend scipy tests on ROCm GPUs (#687)
498f7350
AratiGanesh Add ROCm encoding for test_struct_encoding_determinism (#683)
6701f5d1
magaonka-amd Remove 'mean' from unsupported params for jnp.var (#689)
4803ba8d
magaonka-amd Enable memory space export tests on ROCm GPUs (#690)
2c7ef615
magaonka-amd Implement approx_tanh for ROCm using OCML tanh function (#691)
7071f715
magaonka-amd Enable test deviceless aot compile test on ROCm (#694)
76c5165a
AratiGanesh Skipping testEighTinyNorm due to hipSolver issues (#697)
6df5f178
tsrw2048 Modified memory space export test to run on ROCm (for some tests). (#…
ac4ba1cb
tsrw2048 Add `device test` unit tests for ROCm (JAX v0.9.0) (#705)
0988e6da
AratiGanesh Skip test_tridiagonal_solve_grad test 0.9.0 (#703)
57c00805
AratiGanesh Skip test_batch_axis_sharding_jvp13 test 0.9.0 (#709)
8cb19cfc
phambinhfin Update skip message version from 0.8.0 to 0.9.0 for test_is_finite on…
32ba3e95

Reviewers
No reviews
Assignees
No one assigned