Prefer isinstance(x, type) over type.isinstance
5006973b
Merge pull request #34204 from jakevdp:doc-sidebar
54cf88d9
Merge pull request #32268 from samanklesaria:issues/32267
227cb221
Prefer isinstance(x, type) over type.isinstance
27b3bffa
Add jax.experimental.random to the wheel build
2486a5e3
fix shard_map transpose explicit sharding zero unsharding bug
bf11cd8a
add optional `explain` callback for weakref_lru_cache misses
f6b6edc1
Merge pull request #34209 from jakevdp:fix-wheel
75fd86b0
Merge pull request #34211 from mattjj:andy-customvjp-none
75130d4d
Handle ad.Zero cotangents in _reshard_transpose_fancy.
caaad6e0
Update XLA dependency to use revision http://github.com/openxla/xla/c…
fd57ef61
[Pallas/Mosaic GPU] Enable more `WarpSpecializedPipelineWGTest`s.
1199fd7f
[Mosaic GPU] Add a `bitwidth` field to `Relayout` constraints in layo…
de96f7b9
[pallas:sc] Allowed specifying tiling in `pltpu.emit_pipeline`
2a1a0a96
[XLA:MGPU] Port Tiling to C++.
60a32d38
deviceless aot test
8c2555a8
Add a thread guard config option.
07bb0f6c
[Mosaic] Move Float8EXMYType to tpu.td.
1c3aa2ff
Merge pull request #33542 from keshavb96:deviceless_aot_test
4b258967
Remove redundant test targets that are already executed as a part of …
4e89ce42
respect self.statics in FlatTree.__eq__
9fad589b
Fix precommit breakage
b02123b9
sick
52790002
revive as many cache miss explanations as reasonably possible
06365873
[pallas:sc] Skip a few tests failing when the compiler uses tiled mem…
12dce9c0
Colocated python: Use a wrapper when storing remote objects at the ba…
b10cee69
skip on jaxlib version
d6101d87
Add `halt-for-connection` to `build_artifacts.yml` workflow call.
909e638a
Merge pull request #33839 from jax-ml:pjit-without-linear-util
a8ca82a7
Update XLA dependency to use revision http://github.com/openxla/xla/c…
774597a4
[Mosaic GPU] Use `isinstance(x, mlir_ty)` instead of the deprecated `…
5a50361d
Remove unnecessary return from placement new in Mosaic GPU extension.
568bca12
[mosaic] Added a canonicalization rule for memref.dim of tpu.memref_s…
35ffd406
[Mosaic GPU] Add basic support for DSMEM
8947c9a2
[Mosaic GPU] Add support for using the redux instructions to speed up…
73ad4e5c
[Pallas/Mosaic GPU] Disable `test_ragged_dot_transposed` temporarily.
e27b89b7
Remove some more stale version guards.
741ecce8
[bug] fix grad-of-vmap-of-dynamic_slice with out-of-bound indices
0b4a4407
Allow max_size to be None (infinite).
c1f8ccfc
Don't pass axis_index_groups to prim.bind in psum batching rule if pr…
09c3f1c9
Migrates `builder.create<Op>()` => `Op::create()`
d1918610
Merge pull request #34229 from jakevdp:dynamic-slice-grad-fix
046b70ca
Fix Array `__format__` and `__str__` to handle McJAX arrays that are …
9528be82
[hijax] start implementing custom_vjp on top of hijax primitives
d557a060
Merge pull request #34161 from mattjj:custom-vjp3-youandme
705171ab
Delete the jax_collectives_common_channel_id flag.
077574f8
[indexing] implement strategy='dynamic_slice'
1447d898
Merge pull request #34225 from jakevdp:index-dynamic-slice
eab73dbb
[TPU] Reenable a disabled test that now passes.
6acf46cc
[Pallas MGPU] Use `the cp.async.bulk` instruction for large contiguou…
2f3063c0
[Mosaic GPU] Expose the `mbarrier.complete_tx` instruction to manuall…
d800c3d6
Shorten splash attention kernel name to address https://github.com/ja…
22b3a0c2
[shmap] fix shmap error logic when subclasses of pspec are used
b2217c3b
Run the GetKeys inside GetOrCreate.
6744656b
Merge pull request #34262 from mattjj:shmap-error-isleaf
603cef63
Run g4 fix on the weakref_lru_cache code.
e184e210
Update rules_ml_toolchain version to accommodate custom redistributio…
c5e57367
Fix deserialization with specified layouts via a ShapeDtypeStruct bei…
a930721b
Remove obsolete filegroup
11abf3bd
Update XLA dependency to use revision http://github.com/openxla/xla/c…
b00dc36f
don't warn on complex->real cast in dot transpose
c0fcb0b2
Merge pull request #33708 from mattjj:issue33521
05d4abb0
[hijax] add vmap suport to CustomVJP hijax primitive
902cee57
Merge pull request #34268 from mattjj:custom-vjp3-youandme-2
6f73035e
Update XLA dependency to use revision http://github.com/openxla/xla/c…
d40b4c95
[Mosaic] Add abs, sign, erf, atan2, reduce_min, reduce_prod support t…
447b0041
[JAX] Add IFRT SerDes to jaxlib deps
09748056
Update XLA dependency to use revision http://github.com/openxla/xla/c…
58154561
Merge pull request #34120 from oulgen:mgpu-ops
1ef117f3
[Mosaic GPU] Test approximate math functions properly + eta reduce args
0ad11175
[Mosaic GPU] Enable redux.sync.f32 on Blackwell
098e953a
[XLA:MGPU] Port Replicated wrapper to C++.
658e6505
[mosaic] infer-memref-layout now accepts target shape as a span
dcb011a1
[Pallas:MGPU] Lower `lax.sign` consistently for LANE and WG semantics.
e4fa0259
[Pallas:MGPU] Add Pallas lowering for `pl.debug_check` under WG seman…
a070482d
Add `thread_guard` to the public API.
1242ba32
[Mosaic GPU] Add support for f8 types for WGMMA with lhs in registers…
7cee4db6
[Pallas MGPU] Clip the size of contiguous TMA transfers to the size o…
27a568a3
lax: ensure padtype_to_pads returns Python ints
6ce4314f
[jax.collect_profile] Allow arbitrary options to be passed to XProf
a4036e4a
[indexing] support newaxis in static/dynamic slice strategies
aad0e01c
Plumb prim params for call discharge rule (to handle named_computation_p
96787dd8
[sc] Generalized infer-memref-layout to support SC tiling
06d75423
[Pallas/TPU] Don't lower eqns that have all dropvar outputs (aka DCE at
9e8fae66
[sc] Removed `infer_kernel_arguments` from infer-memref-layout
c984def2
Merge pull request #33974 from Prakharprasun:fix-padtype-to-pads
a10149b6
[Pallas] Allowlist semaphore/prng effects under remat and custom
c5e42402
Add lax.tile_p
c396a367
Automated Code Change
bd07769f
[Mosaic:TPU] Clean up tpu.memref_slice verifier
4691dfee
Update XLA dependency to use revision http://github.com/openxla/xla/c…
9a243d76
Add pallas lowering for jnp.tile using tpu.repeat.
f0ca449a
[export] Add support for explicit sharding.
96ecec4f
Handle the case when shardings are GSPMDSharding.
b7700e47
[Mosaic GPU] Support more reduction kinds and layouts in `MultiDimRed…
e5faf1d2
[Pallas/interpreter] Add a prototype for a GPU kernel interpreter.
6eb71e90
[Pallas:MGPU][NFC] Update outdated docstring for `scratch_view`.
f25011f0
[Mosaic GPU] Fix issue in `vector_dim` reduction with `vec_len > 2`.
678adf59
[Mosaic GPU][NFC] Replace usages of `_gpu_ops_gen` with `gpu` dialect…
35a14b8b
[Mosaic GPU] Fix the return type of `_lift_fast_packed_instr` to matc…
ed12030c
Set the proper memory kinds for output in call_exported
4b76bd88
Merge pull request #33597 from gnecula:export_shit
fe56c711
Merge pull request #34290 from jakevdp:indexing-none
4fdb82cd
[Pallas:MGPU] Remove broadcast in scalar reduce test.
3e7d4dd5
Change configurations for nightly TPU tests:
1a74daa6
[hijax] more custom_vjp support: remat and optimize_remat
4732c732
Update the correct core count for TPU v7 runners
5267d8ad
Fix numpy dependency.
1bbbe8c1
Merge pull request #34320 from mattjj:custom-vjp3-youandme-3
940d2410
Add "jax/_src/pallas/mosaic_gpu/interpret:interpret_pallas_call" to j…
99f65585
Revert switch to using `cp.async.bulk`.
fcbffe9c
Breaking existing tests
0b7fc6a5
[Pallas] Add delay effect
2a5f6379
Add "jax_force_dcn_cross_host_transfers" flag to use DCN instead of t…
4d76b711
[Pallas TPU] Fix bug where in kernels that only contain inter-core
5f9c795f
Use pe._default_dce_rule if a prim doesn't have a `dce` method on it …
4f642fc0
Update XLA dependency to use revision http://github.com/openxla/xla/c…
b66cb459
[Mosaic GPU] Support signed/unsigned min reductions in `MultiDimReduc…
ab6c95cd
[Pallas:MGPU] Implement min/max scalar reductions under WG semantics.
7980031e
[pallas:mosaic] tpu.memref_{slice,squeeze} no longer support strided …
974972b2
Update XLA commit to https://github.com/openxla/xla/commit/c358f68c2a…
e9d2deee
[pmap] Suppress PmapSharding deprecation warning in multihost_utils_t…
12fb9d22
Don't use convert_element_type to cast in dot_general bwd pass. Inste…
b8c00a84
[JAX] Populate Send/Recv frontend attributes.
4b9bbb5a
[pmap] Remove the `jax_pmap_no_rank_reduction` config state.
de8e4e89
replace uses of pltpu.repeat with jnp.tile in pallas.
34a4cbce
[pmap] Expand docs for migrating from pmap to shard_map.
9efab3cf
Remove problematic warnings filter
697b18d9
Add mapped_aval and unmapped_aval to jax.extend.core
ae582232
[dep] Deprecate jax.numpy.fix
09684bc4
Merge pull request #34310 from jakevdp:mapped-aval
69e2fa59
Merge pull request #34293 from jakevdp:dep-fix
52ea5e89
[Pallas:MGPU] Enable `test_load_store_wgmma_transposed` under WG sema…
82183245
Reverts 0b7fc6a58ec36929ff5e98b46a359551839604eb
540671e0
[indexing] fully support indexing & normalization modes within static…
e4e22fc5
[Pallas:MGPU] Simplify Mosaic GPU tests by removing intermediate SMEM…
20627f67
Merge pull request #34327 from jakevdp:warning-filter
2e29fd98
Merge pull request #34326 from jakevdp:indexing-normalize-indices
ddddd11b
relax fp8 sdpa test tolerance
30e528ad
[Pallas/interpreter] Add support for `run_scoped` in the GPU kernel i…
68e20333
[Pallas] Fix broken pallas distributed test
a4befcdd
Remove support for some old format in `jax/tests/export_back_compat_t…
6bb594ab
Reverts b8c00a84ddc1c463fcfe7b7bbdaf2eaa670886df
3422bfcc
Only reset preferred_element_type = x.aval.dtype if explicitly set
7cd6b1a9
Fix backwards compatibility with jaxlib.
267a9a58
[NFC] Use getDefiningOp<Op>
a13f53eb
Update XLA dependency to use revision http://github.com/openxla/xla/c…
643d2e4e
[export] Add backwards compatibility tests for v6.
7a7fd1eb
Merge pull request #33853 from gnecula:export_v6_tests
bcc959b9
[XLA:MGPU] Port TiledLayout's construction logic to C++.
9ac60a34
[Mosaic GPU] Disable redux ops until properly benchmarked and optimized
ef84d389
[XLA:MGPU] Port Partitioned.*Dims and VectorLength methods to C++.
92f08100
Reverts 7cd6b1a9c711068f4f05312d24ef598fe40d4e50
15a3e1ed
[pallas] Make the LoweringDynamicShapeEnv use local mappings
602f212d
[export] Cleanup the handling of has_named_sharding
2df10cb3
Merge pull request #34403 from gnecula:export_fix1
31ef45ef
Merge pull request #34400 from gnecula:shape_poly_pallas_cache_reuse
d66c2a69
[pallas:sc] The lowering can now use an empty grid directly
0c804685
[pmap] Deprecate setting `jax_pmap_shmap_merge` config state.
bb1017b8
Update `rules_ml_toolchain` version to incorporate new GPU folders st…
a8927894
Add a version guard around TPU lowering rule changes.
e16ca653
Add the test rule `compare_srcs_and_test_deps_test` that compares the…
8d97179e
[Pallas:MGPU] Fix scratch size for cross-warp reductions.
695eac8e
Merge pull request #34375 from Cjkkkk:relax_fp8_sdpa_tolerance
c5831d16
[hijax] fix up custom_vjp3 error messages
002c3dc3
[Pallas] Turn off tpu_7x configs for distributed test to avoid hangs
7f6fe293
[Pallas] Remove PRNG/Semaphore effects from basic JAX allowlist since
ee85b82f
[indexing] use strides=None for all unit strides
90076305
[typ] annotate lax.padtype_to_pads
e64dccf8
Merge pull request #34394 from mattjj:custom-vjp3-youandme-4
429fcfc0
Uninstall xprof on python3.13-nogil always.
fe7676d4
[Pallas] Disable non-sharded multi-chip splash attention test config,…
9e481477
[hijax] fix shard_map of hijax primitive
2ecafeee
[pmap] Suppress PmapSharding deprecation warning in (more) multihost_…
3d400948
Merge pull request #34414 from jakevdp:padtype-to-pads
1143012a
Merge pull request #34410 from jakevdp:indexing-strides
f0ff1b85
Migrate TSAN workflow to use RBE.
acdd012a
Follow up pr after triton integration cl and tokamax pr.
91223b93
Merge pull request #34418 from jakevdp:hijax-shard-map
6a9f3d98
[Pallas] Create trace_value primitive for dynamic value logging
02939fef
Add `testonly=True` to py_import targets that depend on testonly `whe…
a84ca0c5
Bump XLA version.
cdfe904a
Prepare for JAX release 0.9.0
dd745d79
Increase shard count for TPU and GPU tests back to 5 for api_test.py.
8f4389a5
Fix wheel sources tests for Windows platform.
fae5defd
Disable `tests/multiprocess:socket_transfer_test` since it's failing …
d97eed6e
Re-enable `socket_transfer_test` internally.
18b34a12
Add libtpu date guard for failing test.
9e7b005e
Skip tpu_pallas_distributed_test on 7x.
77d9ffb0
Add libtpu guard to failing tpu_trace_value_test.
80b1ef6a
Disable `tpu_splash_attention_kernel_test` on TPU v7x.
28799d51
Remove failing test_itof_dot_canonicalization_fails_without_compat_mo…
98d6b4ba
Use maxsize=None with trace_to_jaxpr's weakref_lru_cache to get more …
2de5b8b0
Remove nvidia_wheel_versions
c706c2fa
Make jaxlib targets visible
de91c59a
hipblas typedef fix
6e0eb3ec
No GPU fail
4f3741b3
Wrap HIP inline functions in anonymous namespaces in vendor.h
2b08a20d
SWDEV-512768 - Replace hipGetLastError with hipExtGetLastError
224b5baf
Add shared utility function get_rocm_version to test_util.py
1bac5e70
Fix hipSparse CSR algorithm mappings for ROCm 7
81cdeb49
Make nvidia version data optional for ROCm builds
d88ca192
Fix v_pages quantization and adjust test params for ROCm compatibilit…
61d28322
Address LLVM assertion failure due to a multithreaded use. Update .gi…
8b0b3d10
Add skip of test_is_finite() on Cuda (#565)
6ab35026
Add rocm test requirements file (#570)
835856e0
Let the unit tests use build.py for setting up Bazel commands for uni…
6bd1a15e
adding abort logic to rocm/jax (#590)
a83cc9e9
Skip is_finite tests on ROCm (not in Triton lowering for jax 0.8.0) (…
17c7560c
Fix shared memory limit check for ROCm in test_dot (#596)
e5b66138
Fix Numpy signatures test (#598)
1a823af2
Fix GPU lowering rule for SVD on ROCm devices (#600)
fd0e42c9
Fixed merge conflicts when moving GESVDJ commit from JAX 0.8.0 to 0.8…
aaac58cb
fix merge arts
e8ac89ba
Enabled testTridiagonal to run on ROCm devices (#607)
ee2c254e
Enabled ToeplitzSymmetricConstruction unit tests for ROCm devices (#608)
455f6b10
Enable testSvdSubsetByIndex for ROCm with subset_by_index skip (#603)
fc3bc807
Fix KeyError for bytes_reservable_limit on ROCm
0956a239
Backport cuda_array_interface , testDotAlgorithm, Optimizer test fixe…
ca932346
Enable RngShardingTests (#644)
4a075b9a
Enable reduce_window tests on ROCm (#643)
2557ff08
Enable test_variadic_reduce_window on ROCm (#647)
134825e3
Unskip supported dtypes for testConvolutionsPreferredElementType (#649)
ee5581c8
Enabled test for condition number on ROCm devices. (#613)
10c7e75a
Enabled RNN unit test: test_no_workspace_overflow for ROCm devices (#…
708d732f
Added changes from PR #626 and PR #645. This also fixes merge conflic…
98cc66ca
[Pallas] Fix ROCm GPU architecture detection and route to Triton backend
b77014fd
Enable array interoperability tests on ROCm platform (#660)
91be368b
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
95e1df00
Update sparse test skip messages in v0.8.2 (#653)
e0739e6c
Port skip for "test_prim_tridiagonal_solve" tests from JAX v0.8.0 to …
a988aebf
Fix test_cuda_array_interface test skip condition (#657)
34676201
Enable testMultivariateNormalSingularCovariance on ROCm (#666)
ed14fc6b
Skip test_batch_axis_sharding_jvp because of hipSPARSE issue (#667)
fc10735b
Skip test_tridiagonal_solve on ROCm due to hipSPARSE numerical errors…
93f5fdc1
Update Skip Reason Outputs (#663)
3747d88a
Skip testCudaArrayInterfaceOnNonCudaFails on ROCm platform (#677)
103da597
Enable lobpcg tests on ROCm platform (#681)
d3cce970
Enable lax backend scipy tests on ROCm GPUs (#687)
498f7350
Add ROCm encoding for test_struct_encoding_determinism (#683)
6701f5d1
Remove 'mean' from unsupported params for jnp.var (#689)
4803ba8d
Enable memory space export tests on ROCm GPUs (#690)
2c7ef615
Implement approx_tanh for ROCm using OCML tanh function (#691)
7071f715
Enable test deviceless aot compile test on ROCm (#694)
76c5165a
Skipping testEighTinyNorm due to hipSolver issues (#697)
6df5f178
Modified memory space export test to run on ROCm (for some tests). (#…
ac4ba1cb
Add `device test` unit tests for ROCm (JAX v0.9.0) (#705)
0988e6da
Skip test_tridiagonal_solve_grad test 0.9.0 (#703)
57c00805
Skip test_batch_axis_sharding_jvp13 test 0.9.0 (#709)
8cb19cfc
Update skip message version from 0.8.0 to 0.9.0 for test_is_finite on…
32ba3e95
Assignees
No one assigned
Login to write a write a comment.
Login via GitHub