jax
CI: 04/30/25 upstream sync
#392
Merged

CI: 04/30/25 upstream sync #392

charleshofer merged 900 commits into rocm-main from ci-upstream-sync-183_1
rocm-repo-management-api-2
rocm-repo-management-api-2 rocm-repo-management-api-2 requested a review 240 days ago
rocm-repo-management-api-2[bot] rocm-repo-management-api-2[bot] enabled auto-merge (rebase) 240 days ago
charleshofer charleshofer force pushed from 806190d0 to 35fa98b3 240 days ago
disabled auto-merge 240 days ago
Manually disabled by user
charleshofer charleshofer enabled auto-merge (rebase) 240 days ago
charleshofer
charleshofer approved these changes on 2025-04-30
disabled auto-merge 240 days ago
Rebase failed
hawkinsp JAX release v0.6.0
9e9a3f03
hawkinsp Update version numbers after release.
f7e3fe62
yashk2810 Remove _manual_axes from NamedSharding since we can now track the man…
46fb4001
yashk2810 Remove jax_varying_axes_in_types config and `rewrite` from `shard_map_p`
85eb58a1
Google-ML-Automation Use llvm::cast/dyn_cast/isa since alternatives are deprecated in http…
4254524f
superbobry Added `lax.axis_size` and switched all existing usage of `psum(1, ...…
f5101ef7
superbobry Do not use `-> ...`
2bf0ca19
Google-ML-Automation Automated Code Change
fbfaa0d1
vfdev-5 Suppress race between split_keys_entry_added and dict_dict_merge in 3…
b55a4e94
Google-ML-Automation Update XLA dependency to use revision
1770096e
yashk2810 Remove shard_map rewrite machinery since we have `vma` in types by de…
1fec3941
yashk2810 Make `ShardMapTracer` track `vma` instead of `rep`.
0ca67825
nitins17 Exclude running tests against the oldest supported libtpu for releases
7f709323
yashk2810 Use _make_lengths_same for explicit mode too.
4f5ed5a2
justinjfu [Pallas] Fix potential race condition in Pallas TPU docs
43287e37
Google-ML-Automation Enable cross-compilation builds of the wheels.
8d8385ad
pschuh Remove unused jax_spmd_mode flag.
c9024ae0
superbobry [pallas:mosaic] Replaced `device_type=` with `kernel_type` in `TPUCom…
7774297a
charleshofer Make Clang use manylinux C++ standard library
9fbbffb3
zacmustin Remove roofline subtests so they work with pytest.
06251b38
dfm Don't recompute source_info for each tracer during staging.
e6ac43a0
Google-ML-Automation Return empty `wheel sources map` in case if `wheel_sources` is `None`.
e2e2bd1e
hawkinsp Remove code to support jaxlib < v0.6.
d100df27
yashk2810 Enter into auto mode for `.at[...].get(...)` a bit earlier so that al…
ffae4a29
yashk2810 Rename `with_user_mesh` to `with_explicit_mesh`
b93f0282
sharadmv [Pallas] Generalize BlockSpec to support different indexing mode for …
9558ca0d
tlongeri [Mosaic TPU] Enable non-sublane-aligned 2D int8 load/stores
cb04ff56
Google-ML-Automation Update XLA dependency to use revision
6a811093
hawkinsp Remove unused type forwarding declarations for jax.lib.xla_client.{Sh…
98259c61
dfm Reverts c2ba1790417ca206a4d88b25aef4d5ae510dd717
f66d71e4
hawkinsp Allow the CPU collective implementation to be overridden to None.
5f57c666
hawkinsp [JAX] Remove jax.lib.xla_client.mlir_api_version and its uses.
59b155ac
hawkinsp Don't use the deprecated jax.dlpack.to_dlpack in tests.
ea3557c9
yashk2810 Drop into `Auto` mode for `.at[...].set(...)` but instead of taking a…
4fb72216
hawkinsp Mark CliDebuggerTest as a thread-unsafe test.
1f865a0b
dfm Ensure outputs are tracers when inlining jit.
a4d2e28e
ZixuanJiang #sdy Do not close any partially sharded dimensions if using auto axes…
3d2c2dd1
yashk2810 Handle `sharding` param in convert_element_type's batching rule prope…
30e3809d
Google-ML-Automation Makes Effort_02 the default value for memory_fitting_level.
67c4e75c
rdyro Refactor array serialization into separate JAX and tensorstore logic
b8c3542d
gspschmid Add jax.fwd_and_bwd
1144cb03
yashk2810 Update changelog to add information about breaking change with respec…
7934a4ba
Google-ML-Automation Make `Shape::add_dimensions()` validate arguments by default.
0b8b0e2f
Google-ML-Automation Update XLA dependency to use revision
76124057
rdyro fix BUILD file syntax error
3d918f20
Google-ML-Automation [Pallas/Fuser] Add support for pl.Element in fuser BlockSpec
d587b76a
Google-ML-Automation Update XLA dependency to use revision
29d743e6
Google-ML-Automation Add support for TPU7x in Mosaic.
87f086e5
Google-ML-Automation Update XLA dependency to use revision
ca9f3d93
emergenz fix: typo
a36a8f71
vfdev-5 Fixed cached mlir module mutation issue in export
410d56e5
jakevdp jax.random.bernoulli: add mode='high' for improved sampling for small p
c187b5cc
jonbarron Simplify (and potentially accelerate) fftfreq(), as a single vectoriz…
cf56cb80
yashk2810 Allow all reshapes if the operand is fully replicated
4a4a0f05
Caslyn [docs] fix typo in doc link
ab2516db
cgarciae implement _replace_with for PRNGKeyArray
844187ff
Google-ML-Automation Switch JAX to the new `ProgramShape::AddParameter()` API.
1193da06
bythew3i [ragged-paged-attn] Add auto-tuned table that is being used for vLLM
40eb37f9
Google-ML-Automation [Pallas][jax] Better error message for unexpected types in standard a…
f24595cd
hyeontaek [JAX] Use `xla::ifrt::Client::MakeArraysFromHostBufferShards()` in Ar…
7688ee7e
andportnoy [Mosaic GPU] Align on device profiler array in smem to 8 bytes
374a5edd
bartchr808 #sdy Add more debug info when there is a mesh mismatch in JAX export
42811dcb
superbobry [mosaic_gpu] Fixed the export code path in `_mosaic_gpu_lowering_rule`
b877daaa
Google-ML-Automation Update XLA dependency to use revision
8c675bca
jakevdp jnp.frexp: add custom JVP rule for proper derivatives.
fc3dba80
hawkinsp Add support for generating freethreaded pip lockfiles.
18a5f05f
jburnim [Pallas] Support the new TPU interpret mode with pl.core_map.
af4cd548
dfm Fix ensure_compile_time_eval context after stackless.
81dd0b47
scottstanie BUG: return one covariance matrix per `rhs` in `polyfit`
1cfa3894
jakevdp typing: improve annotations for scatter implementations
8347142b
justinjfu [Pallas][Mosaic TPU] Improve error message for 1D block specs when th…
445fb63b
MichaelHudgins [CI] Add tpu v6e-8 to nightly test and release test.
6809c690
justinjfu [Pallas] Update TPU pipelining docs
aca090d7
sharadmv [Pallas Fuser] Allow multiple BlockSpec inputs to select_n push rule …
56e12fd6
dfm Fix handling of SymbolicZero output when batching custom_jvp.
26691f39
dimitar-asenov [Mosaic GPU] Do not create a barrier with `arrival_count == 0` in the…
3a5504e5
sharadmv [Pallas Fuser] Change physicalize to resolve_fusion_dtypes
755265d5
jburnim [Pallas] Propagate Jaxpr effects through pl.fusible
f060fcc6
Google-ML-Automation Set core_index to default for tpu_pallas_async_test
ebb21ee5
jakevdp jnp.ldexp: avoid overflow for large exponent
541c1169
hawkinsp Update tsan requirements patch after lockfile update.
0dccb0f5
yashk2810 Initial commit of non-experimental `shard_map`.
d3e19fea
bythew3i [ragged-paged-attn] Autotune set up and increase page size to avoid S…
6321af49
jakevdp DOC: link to ai-stack tutorials from JAX's front page
e0295f7d
hanrach9 Skip unary_ops_accuracy test for TPU version 7 and above.
36a6c12b
tlongeri [Mosaic:TPU][Relayout] Row shifts for packed types and non-native til…
06390001
charleshofer Account for versioned clang binaries
5253e99f
yashk2810 Expose `jax.shard_map` as the public non-experimental API and move sh…
330f414a
justinjfu [Pallas] Make Pallas PRNG more robust by improving the Mosaic seed br…
d35c72df
yashk2810 Make `set_mesh` thread local so that it behaves exactly like `use_mesh`
fbbbb7b6
superbobry [pallas:mosaic_gpu] Added an export backward compatibility test for a…
069409c3
rajasekharporeddy Replace reference to jax.readthedocs.io with docs.jax.dev in jax._src…
d62d667b
apaszke Include MGPU in grid/blockspec tutorial + use proper note/warning for…
26a5bce9
jakevdp jnp.ldexp: avoid overflow for large exponent
541c1169
yashk2810 Initial commit of non-experimental `shard_map`.
d3e19fea
bythew3i [ragged-paged-attn] Autotune set up and increase page size to avoid S…
6321af49
jakevdp DOC: link to ai-stack tutorials from JAX's front page
e0295f7d
hanrach9 Skip unary_ops_accuracy test for TPU version 7 and above.
36a6c12b
charleshofer Account for versioned clang binaries
5253e99f
yashk2810 Expose `jax.shard_map` as the public non-experimental API and move sh…
330f414a
justinjfu [Pallas] Make Pallas PRNG more robust by improving the Mosaic seed br…
d35c72df
yashk2810 Make `set_mesh` thread local so that it behaves exactly like `use_mesh`
fbbbb7b6
rajasekharporeddy Replace reference to jax.readthedocs.io with docs.jax.dev in jax._src…
d62d667b
Google-ML-Automation Update XLA dependency to use revision
fa4af688
mrodden Update install docs for AMDGPU
f58f0753
dfm Inline literals while tracing instead of in a separate pass.
58ea6785
dfm Handle DShapedArray in caller.
830ffaa7
Google-ML-Automation Fix typo "dataclasss".
964e8292
hawkinsp Fix some typos.
7880c58d
hawkinsp Fix mypy errors related to jaxlib.
4e950d9c
jakevdp Cleanup: don't redefine softmax in jax.random
4a5375f6
bchetioui [Pallas/Mosaic GPU] Allow explicit `smem` aliasing.
acedda03
yashk2810 Add a pvary on the `ones` we create in grad because the `ans` in the …
4853aa47
zhenying-liu Add a document for activation and parameter offloading
ab5f6706
vfdev-5 TSAN FT CI, install cython from nightly wheels
2b67b740
hawkinsp Move contents of jaxlib/xla into jaxlib/
2a245036
dfm Reverts 49e25c6167806ba90efe8370fb04db3f4966437c
9923ba80
yhtang Automatically initialize distributed runs in K8s indexed jobs
4bb2e2a5
yashk2810 Skip host_offloading notebook from running
c08fd51f
jakevdp jax.numpy: make standard input utilities respect __jax_array__
466d8c49
yashk2810 internal change
b554a65e
bixia1 Set is_custom field for custom_partitioning_sharding_rule so that rul…
8f82928b
yashk2810 Remove config `_check_rep` to `_check_vma` and the kwarg in shard_map…
8f87dc19
superbobry [pallas] `pl.pallas_call` no longer allows `compiler_params=` to be a…
811342ca
yashk2810 Migrate jax.experimental.shard_map to jax.shard_map in internal JAX a…
57c5175e
superbobry [pallas] Added a note on the recent `compiler_params=` change to the …
44400ff5
superbobry [pallas] Fixed the type of `MemoryRef.dtype`
bba6c56f
dimitar-asenov [Mosaic GPU] Introduce a dedicated `DialectBarrierRef` and handle war…
4df5a5e8
Google-ML-Automation Update XLA dependency to use revision
2a535e45
dfm Update FFI tutorial to new shard_map.
a29196af
yashk2810 Change pallas/distributed.{ipynb,md} to use jax.shard_map
f25bed71
dfm Fix shardy sharding rule for SVD decomposition.
fc62fa19
tomnatan30 #sdy avoid an extra call to `mlir.module_to_bytecode` and use `mlir::…
5f08cd04
Google-ML-Automation Refactor LocalMask to inherit from _ComputableMask
fa56b7b3
superbobry [pallas:mosaic] Fixed a bug in `pl.debug_print` lowering
c1630e1d
yashk2810 Add `__str__` to `UnshapedArray` so that whenever we `print(aval)`, w…
6cde5323
superbobry Ran pyupgrade --py310-plus on .pyi files in jaxlib
9e937fc0
yashk2810 Make inspect_array_sharding inside shard_map work with `check_vma=Tru…
d1523105
jakevdp Cleanup: remove superfluous jax.numpy utility
34d8d4ff
yashk2810 Remove PositionalSharding from JAX from almost all places except for …
718c103a
superbobry [pallas] Removed the deprecated `jax.experimental.pallas.gpu` module
531c5d4c
Rifur13 [Mosaic GPU] Adding a deterministic backwards pass to the pallas MGPU…
a0dfbcd2
dfm Re-land: "Don't recompute source_info for each tracer during staging".
d1cb3016
mattjj Update README.md to remove jax.pmap
ce4c923c
mattjj use :check: and :x: emojis in readme table
5a1be75b
mwhittaker Added new multi-controller JAX guide.
e546ad98
yashk2810 Update the sharding table to be the same everywhere
bc601c40
rdyro fix: in jax.asarray check default_device instead of default_backend
bf5ec8cd
yashk2810 Set __module__ on NamedSharding, SingleDeviceSharding and PmapShardin…
f4debd71
Google-ML-Automation Add mosaic tests to optional B200 GPU presubmit.
76dc1e7a
mattjj [readme] include teaser example of shardings
b4ea9e91
yashk2810 Move sharding table above the code example
28716264
vfdev-5 Updated TSAN suppresssions files to get traceback of crashed process
711e26c1
rickeylev fix: make build.py use /usr/bin/env python3 as shebang
a10ad049
Google-ML-Automation Fix typo in Bazel command for Mosaic tests.
a186d792
Google-ML-Automation [Pallas/Fuser] Bugfix for broadcasting, lax.slice_p, and lax.dynamic_…
3a451fff
danielsuo Update `debug_info` tests for `use_direct_linearize`.
6ce2bb55
mattjj [scan] when a carry is read-only, move it to be a const
f7611ef5
yashk2810 Fix shard_map's direct linearize after vma has been turned on
aea756d7
superbobry [pallas] Use `*MemorySpace` aliases
7e7e3e5a
sharadmv [Pallas TPU] Introduce a BoundedSlice block shape type
3b155efa
WillFroom [JAX:Sparse] Add BCSR benchmarks to sparse_benchmarks.py.
537b8754
apaszke [Mosaic GPU] Fix up the Blackwell code to match changes in the CUDA r…
582ab0c1
apaszke [Mosaic TPU] Allow indexing refs with narrow integers
923364ea
apaszke [Mosaic TPU] Allow simultaneous column and row shifts
2e5fa04d
Google-ML-Automation Update XLA dependency to use revision
526ccbd9
apaszke [Mosaic TPU] Relax test restrictions after improving DMA support for …
5baa2a5c
tomnatan30 #jax #sdy add a Shardy config to mock_gpu_topology_test.
953379e3
apaszke [Pallas] Remove leftover debug=True in Pallas tests
94bf5d0c
hawkinsp [JAX] Remove xla_client_test.
7dd98cb4
MichaelHudgins [CI] Correct the optional GPU presubmit to work properly with workflo…
24543a37
tomnatan30 #jax #sdy pass the value of `use_shardy_partitioner` to `get_compile_…
abb8999f
apaszke [Mosaic GPU] Wire up the plugin that contains the MGPU custom call in…
85b7e8fc
dfm Make callbacks work in unrolled loops on TPU.
6156b6fa
hawkinsp [XLA:Python] [JAX] Change JAX to use the _profiler module defined in …
d300eb87
yashk2810 Raise a better error in `jax.linearize` when vma's don't match betwee…
af011cf4
yashk2810 Allow decorator factory pattern for `jax.shard_map` i.e. `@jax.shard_…
0e47af50
tomnatan30 change build target
6bb73c22
mattjj roll back #28210 because it broke an internal test
56667fd6
WindQAQ [Mosaic] Support bf16 select and cmp <= TPUv4.
e80c4222
jakevdp jax.numpy: move linspace & friends into array_creation
3a7728f2
vfdev-5 Minor optim in profiler on session stop and export
4ef4c1ee
yashk2810 Replace auto tracking with manual axis names tracking internally in s…
c5e43670
carlosgmartin Edit print_environment_info to print environment variables that start…
9b38623e
vam-google Fix wrong shard_map import
01ddb2f9
yashk2810 Add conv_general_dilated sharding rule
0b7befdf
sharadmv [Pallas] Fix index_map equality checks
9434f45c
yashk2810 Remove `rep` completely from shard_map (rip rep from shmap -- RIP). T…
d1e0e58b
Google-ML-Automation Add multiaccelerator H100 tests to optional GPU presubmit.
90e985c7
mattjj Reverts fbb0cbbc6bfe30941076a365a427e12efbb5253b
8a72e754
ezhulenev [jax] Switch lapack kernels to builtin FFI attrs decoding
07ca41f5
hawkinsp [XLA:Python] [JAX] Move XlaBuilder bindings out of JAX and into XLA:P…
b3e42331
hawkinsp [JAX] [XLA:Python] Move ShapeIndex bindings out of JAX and into XLA.
2fb49804
yashk2810 Remove PositionalSharding from JAX from almost all places except for …
718c103a
superbobry [pallas] Removed the deprecated `jax.experimental.pallas.gpu` module
531c5d4c
Rifur13 [Mosaic GPU] Adding a deterministic backwards pass to the pallas MGPU…
a0dfbcd2
dfm Re-land: "Don't recompute source_info for each tracer during staging".
d1cb3016
mattjj Update README.md to remove jax.pmap
ce4c923c
mattjj use :check: and :x: emojis in readme table
5a1be75b
mwhittaker Added new multi-controller JAX guide.
e546ad98
yashk2810 Update the sharding table to be the same everywhere
bc601c40
rdyro fix: in jax.asarray check default_device instead of default_backend
bf5ec8cd
yashk2810 Set __module__ on NamedSharding, SingleDeviceSharding and PmapShardin…
f4debd71
Google-ML-Automation Add mosaic tests to optional B200 GPU presubmit.
76dc1e7a
mattjj [readme] include teaser example of shardings
b4ea9e91
vfdev-5 Updated TSAN suppresssions files to get traceback of crashed process
711e26c1
rickeylev fix: make build.py use /usr/bin/env python3 as shebang
a10ad049
Google-ML-Automation Fix typo in Bazel command for Mosaic tests.
a186d792
Google-ML-Automation [Pallas/Fuser] Bugfix for broadcasting, lax.slice_p, and lax.dynamic_…
3a451fff
danielsuo Update `debug_info` tests for `use_direct_linearize`.
6ce2bb55
mattjj [scan] when a carry is read-only, move it to be a const
f7611ef5
yashk2810 Fix shard_map's direct linearize after vma has been turned on
aea756d7
superbobry [pallas] Use `*MemorySpace` aliases
7e7e3e5a
sharadmv [Pallas TPU] Introduce a BoundedSlice block shape type
3b155efa
WillFroom [JAX:Sparse] Add BCSR benchmarks to sparse_benchmarks.py.
537b8754
apaszke [Mosaic GPU] Fix up the Blackwell code to match changes in the CUDA r…
582ab0c1
apaszke [Mosaic TPU] Allow indexing refs with narrow integers
923364ea
Google-ML-Automation Update XLA dependency to use revision
526ccbd9
apaszke [Mosaic TPU] Relax test restrictions after improving DMA support for …
5baa2a5c
tomnatan30 #jax #sdy add a Shardy config to mock_gpu_topology_test.
953379e3
apaszke [Pallas] Remove leftover debug=True in Pallas tests
94bf5d0c
hawkinsp [JAX] Remove xla_client_test.
7dd98cb4
tomnatan30 #jax #sdy pass the value of `use_shardy_partitioner` to `get_compile_…
abb8999f
apaszke [Mosaic GPU] Wire up the plugin that contains the MGPU custom call in…
85b7e8fc
dfm Make callbacks work in unrolled loops on TPU.
6156b6fa
hawkinsp [XLA:Python] [JAX] Change JAX to use the _profiler module defined in …
d300eb87
yashk2810 Raise a better error in `jax.linearize` when vma's don't match betwee…
af011cf4
tomnatan30 change build target
6bb73c22
mattjj roll back #28210 because it broke an internal test
56667fd6
WindQAQ [Mosaic] Support bf16 select and cmp <= TPUv4.
e80c4222
jakevdp jax.numpy: move linspace & friends into array_creation
3a7728f2
vfdev-5 Minor optim in profiler on session stop and export
4ef4c1ee
yashk2810 Replace auto tracking with manual axis names tracking internally in s…
c5e43670
carlosgmartin Edit print_environment_info to print environment variables that start…
9b38623e
sharadmv [Pallas] Fix index_map equality checks
9434f45c
yashk2810 Remove `rep` completely from shard_map (rip rep from shmap -- RIP). T…
d1e0e58b
Google-ML-Automation Add multiaccelerator H100 tests to optional GPU presubmit.
90e985c7
mattjj Reverts fbb0cbbc6bfe30941076a365a427e12efbb5253b
8a72e754
hawkinsp [XLA:Python] [JAX] Move XlaBuilder bindings out of JAX and into XLA:P…
b3e42331
Google-ML-Automation Update XLA dependency to use revision
94a843a3
Google-ML-Automation Update XLA dependency to use revision
af568312
dimitar-asenov [Mosaic GPU] Add a `BroadcastInDim` op in the mosaic mlir dialect.
906cb417
dimitar-asenov [Mosaic GPU] Handle WGMMA_ROW_LAYOUT in `vector_store` lowering and i…
6d49c8f5
bchetioui [Mosaic GPU] Delete `gpu.binary` after lowering.
8138dccf
apaszke [Mosaic GPU] Make sure the tests all pass on B200 (esp. in our CI)
4cbef87e
apaszke [Pallas:MGPU] Increase shard_coaunt for mgpu_attention_test to avoid …
404f7bc8
gflegar [Mosaic GPU] Relax tolerance for WGMMATest.test_narrow_n
c52ade7b
apaszke [Mosaic GPU] Fix a skip condition to avoid problems with nightly jaxl…
a835750d
nitins17 Fix formatting error in matrix exclude strategy
d7d545b3
mattbahr implement hyp2f1
d8cdedb2
Google-ML-Automation Update `build.py` to avoid duplication of bazel options both in bazel…
ca971c34
WindQAQ [Mosaic] Remove Python pipeline.
82597a9e
dfm Relax constraints in jnp.vectorize for output shapes with default sig…
4523d3ab
WillFroom [JAX:sparse] Fix bcsr.from_bcoo to use the index_dtype of the input B…
34773fdc
hawkinsp Fix warning in mosaic/pipeline.py under Python 3.12.
6d6ae50c
jakevdp Ensure __jax_array__ works properly with JIT disabled.
89ff6932
yashk2810 Allow 2x2x2 topologies with v6e
e3b45ad5
apaszke Fix scaled_matmul_stablehlo_test in x64 mode
4abcfb9e
dfm Handle extended dtypes in FFI lowering.
21745c1a
apaszke [Mosaic TPU] Add support for narrow integer `arith.constant`s in kernels
5d837dd0
emilyfertig Add sharding devices to XlaCompileOptions and plumb them through from…
90ab27d6
yashk2810 Add `standard_insert_pvary` support to `reduce`.
57d7a117
kanglant [JAX] Add a python 3.14 requirement lock file and update WORKSPACE
1a7ce898
dependabot[bot] Bump actions/setup-python from 5.5.0 to 5.6.0
3cbc5a3e
gentlelovebear [jax/docs] - Fix link in quickstart.md
57a6748e
hawkinsp Disable logsumexp test under SciPy 1.15.
42aa3ebc
rickeylev chore: load py_library from rules_python
f0d02bf7
MichaelHudgins Change nightly installation instructions and CI to use new package in…
4ac29b92
mduszyk Improving PyTorch data loading notebook
b2400141
nitins17 Disable old JAX nightly builds
7eb3f5bb
jburnim Small changes for easier testing of kernels with TPU interpret mode.
58e75bf9
MichaelHudgins Update JAX nightly index usage
a1fe4ca5
yashk2810 Enter into correct mesh context in shard_map in `_partial_eval_jaxpr_…
727fe84c
changhuilin Add GetAllocatorStats() method for device.
ba029782
dimitar-asenov [Mosaic GPU] Add layout inference and lowering for `vector.MultiDimRe…
e949ab4c
apaszke [Mosaic GPU] Add suport for packed TMEM
fdc4209f
nitins17 Add note to changelog about nightly package switchover
1d76b6d0
nitins17 Allow nightly workflows to halt for connection
80e89e86
Google-ML-Automation Update XLA dependency to use revision
65f0e17b
dfm Re-land "Inline literals while tracing instead of in a separate pass".
f0a479e1
dfm Remove unused parameters in emit_python_callback.
641788a3
dfm Remove deprecated GPU linalg kernels after compatibility period.
5f15f770
armandpicard fix: use backend to call xb.process_count in _raise_warnings_or_error…
7472ea3f
dfm Fix deprecation warnings in cudnn scaled matmul.
4b21c729
Google-ML-Automation Support batch_axis being int rather than Sequence[int] in initializers.
af6d3d70
hawkinsp [JAX] Remove jaxlib/xla directory.
0f2a89dd
Melody-coder923 SPMD with jax.jit, namedsharding and partition spec
ad051116
hawkinsp Fix test failure in autodidax due to compile() API change.
af336765
hawkinsp PR #28391: Add version guard in autodidax test.
88d6a8fc
hawkinsp Update autodidax.ipynb and autodidax.md
aa7bdd67
dfm Show Literal's aval when pretty printing.
cbe2b203
hawkinsp Relax test tolerance for PureCallbackTest.test_can_take_grad_of_pure_…
6d53a573
hawkinsp Set PYTHONWARNINGS=error for jax_multiplatform_test.
e7fd33e3
rickeylev chore: load py_library from rules_python
f0d02bf7
dimitar-asenov [Mosaic GPU] Add layout inference and lowering for `vector.MultiDimRe…
e949ab4c
apaszke [Mosaic GPU] Add suport for packed TMEM
fdc4209f
dfm Re-land "Inline literals while tracing instead of in a separate pass".
f0a479e1
hawkinsp Round-robin pytest tests across GPUs.
f8b77865
sharadmv [Pallas TPU] Add better support for BoundedSlice and other BlockDim t…
ff8341d6
charleshofer charleshofer force pushed from e9c8fdcb to 7c833647 239 days ago
charleshofer charleshofer merged 7c833647 into rocm-main 239 days ago

Login to write a write a comment.

Login via GitHub

Reviewers
Assignees
No one assigned
Labels
Milestone