jax
CI: 01/27/25 upstream sync
#213
Merged
github-actions merged 298 commits into rocm-main from ci-upstream-sync-97_1
[sharding_in_types] Introduce `auto_mode`, `user_mode`, `auto_mode_ct…
a817f532
Fix pmap with sharded typed prng key
aeac6b03
Update XLA dependency to use revision
4e729316
[shardy] Fix cases in shardy where you have a nullary function with p…
6b253b2f
Update XLA dependency to use revision
ed349e45
[Mosaic TPU][NFC] Remove redundant num_subelems attribute from Create…
09302899
[Mosaic TPU] Enable non-sublane-aligned bf16 2D load/stores for earli…
78520455
Use thread-safe initialization of LAPACK kernels.
91ffb640
Add JAX unit test for Shardy which causes the compiler to introduce t…
c14e5b43
[better_errors] Ensure that tracer errors in for_loop points to use code
3faff78c
[Mosaic TPU] Add support for arith.fptosi with non-32bit source and t…
391bad8f
[Pallas TPU] Skip cast test incompatible with older libtpu builds
aa51f2af
Update XLA dependency to use revision
9ee25e7e
Internal change.
e72c1484
Even more linearize fixes
96769f96
Merge pull request #25833 from jax-ml:more-linearize-fixes
dabe27bc
[Mosaic GPU] Fix layout API bugs.
f69592ae
[Pallas TPU] Add vector support to `pl.debug_print`
9ba1fd28
enable partitionable threefry by default
a60ead6f
Merge pull request #25798 from gnecula:fix_fori_error
4f2f5fa5
#sdy dynamically choose which `custom_partitioning` API to use based …
74e912c3
[jax2tf] Fix bitrot in docs
36533b9e
[better_errors] Add debug_info to DynamicJaxprTrace and JaxprStackFrame
b30df36d
Mention expected tangent aval in error message, see #25517.
7d11d12b
Merge pull request #25875 from jax-ml:issue-25517
29a3dded
Pin actions/checkout to a commit
a1bbad68
[sharding_in_types] Handle ShapeDtypeStruct inputs with sharding_in_t…
c72ed260
Merge pull request #25827 from gnecula:debug_info_2
ee724565
Remove dead codepaths now that MemorySpaceDescription works in OSS
b7e06f19
Merge pull request #25872 from gnecula:jax2tf_doc
f270739f
Update XLA dependency to use revision
2408fb7d
Fix remat bug on primitives with multiple outputs.
b6acb9cb
Reverts 391bad8ff59c07c8fad7b8ce05cd0e29dee4cf1a
f1b894d1
Indexing: avoid dynamic_slice when mode='clip'
54fbf0b3
Rename test configs to include GPU variants more consistently.
f122f17b
[Mosaic TPU] Append dump id to timestamp to make dump list ordered
6851700e
[Docs] Remove `--xla_gpu_enable_triton_softmax_fusion` from docs
57a259f4
[Mosaic GPU] Enable x64 tests for mosaic gpu.
ff5cb811
Add Github action workflows for running continuous tests with Pytest
c78487d2
Temporarily disable GQA attention tests on GPU, which were broken by …
d1810b42
[Pallas] Fix GQA triton kernel test.
cc9f6e75
[ROCm] Implement RNN support
fe68eb8b
[pallas][mosaic kernel export] Add initial support for exporting a dy…
c18492be
[pallas] Fix bad rebase, deleted lowering for a print
c4406d27
[better_errors] More cleanup
f9dfe7f6
[Pallas TPU] Temporarily strengthen restrictions on Pallas tests
aa19f9c4
[Mosaic GPU][NFC] Address some previous stylistic comments.
cdf490a5
[pallas:mosaic_gpu] Fixed a crash in MLIR Python bindings
afcb21dd
More linearization fixes
9fe553ca
Merge pull request #25880 from jakevdp:fix-gather
2e5e4799
Add gfx12xx archs
435edf1f
Merge pull request #25876 from gnecula:debug_info_3
70c1ee5d
Jax: Stop returning a list of cost-analyses.
2d72e8de
Update XLA dependency to use revision
51f23100
Merge pull request #25864 from jax-ml:yet-more-linearization-fixes
ca012d7a
Merge pull request #25755 from ROCm:ci_rnn_final-upstream
41993fdb
Update the JAX version to 0.5.0.
3a8f31aa
Make utils for reporting function name work with `functools.partial` …
f7d097f7
Merge pull request #25911 from hawkinsp:version
2fa10020
Merge pull request #25906 from ROCm:ci_add_new_gfx-upstream
cf67e28f
Move halt for testing step to be just before running tests
8a053af1
[Mosaic] Add a macro to convert abseil StatusOr to LLVM FailureOr
d3ba1eb3
Fix run_multi_gpu script multi-gpu issue and refactor code
8e88adcd
[Mosaic] Allow passing `ApplyVectorLayoutCtx` to tpu.apply_layout_op.
4a9cc9ff
[Pallas TPU] Add helpers file with copy_ref function
0ac63157
Merge pull request #25917 from ROCm:ci_fix_multi_gpu_test_logic-upstream
9a60e6fc
[mosaic] Extracted serialization pass traversal logic into a reusable…
4221f109
[Mosaic GPU] Allow querying layouts from a `FuncOp`'s block arguments…
bc7204f0
[MosaicGPU] Extract code into a new method `BarrierRef.from_dialect_b…
22417ae2
[Mosaic TPU] Add support for packing to 16-bit integers on TPUv4
ef4dbd9c
[MosaicGPU] Move `gpu_address_space_to_nvptx` inside `utils.py` and u…
ce03cf97
[sharding_in_types] Expand reshape's sharding rule to add support for…
c6b5ac5c
[Mosaic GPU][NFC] Simplify and clean up layout inference tests to use…
3366c927
[MosaicGPU] Remove the single_thread context from top-level dialect c…
24884071
[Mosaic GPU] Add layout inference for splat `arith.ConstantOp`s and `…
d3bf2433
[Mosaic GPU][NFC] Clean up import to align with stylistic guidance.
6746d633
[MosaicGPU] Cleanup imports in dialect_lowering.py
5e27efd0
[Mosaic TPU] Improve support for int16->int32 casts in TPUv4
8954e71d
Make `result_handler` of `_DeferredShardArg` a method instead of a pr…
0df4475a
[Mosaic] Fix infer/apply extensions.
5c020ee3
Update XLA dependency to use revision
994c3f59
[sharding_in_types] If an indexing operation hits into `gather_p`, er…
b23c4237
Allow resharding between tokens on a single device
f2f552c1
Store GCS upload URI as a step output
5e52031d
Add quotes to pip commands in docs around option install for zsh
aa9cea0a
[Mosaic TPU] Use large to compact 2nd minor retiling for conversions …
bd22bfef
Add ensure_arraylike utility for lax.numpy implementations
4c926c8d
Fixing bwd attention test tolerance level
2cdd9b7d
Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Coll…
49224d6c
Rename out_type -> out_sharding parameter on einsum
97cd7483
[sharding_in_types] Rename `.at[...].get(out_spec)` to `.at[...].get(…
af667199
[mosaic_gpu] Added a serialization pass
d34c40f6
[Mosaic GPU] Delete unused declarations of `mosaic_gpu_memcpy_async_h…
d3be190e
Update XLA dependency to use revision
7cac76d3
[sharding_in_types] Error out for reshape for splits like this: `(4, …
ce85b898
Reverts f1b894d14a28ac22a037fb79177b991275c75a18
a527aba6
Tweak documentation of jnp.cov to include scalar return for M = 1
df6140e8
Release JAX 0.5.0
c25fb92c
Merge pull request #25952 from johannahaffner:patch-1
a4a657bc
[key reuse] fix signature for device_put
f83175fc
Merge pull request #25957 from jakevdp:device-put-key-reuse
232662aa
Add performance note at the top of sparse docs
3141453b
Merge pull request #25642 from Rifur13:numerical_stability
4d20052f
Merge pull request #25936 from jakevdp:ensure-arraylike
bda52c36
Update XLA dependency to use revision
318764b8
Merge branch 'release/0.5.0' into main
b0a71357
Update version numbers after 0.5.0 release
9fa29122
Merge pull request #25958 from jakevdp:sparse-warning
fe6172d1
[sharding_in_types] Rename `sharding_cast` to `mesh_cast` and add a f…
695c02b1
Merge pull request #25961 from hawkinsp:postrelease
093dd9f4
Use ensure_arraylike utility in jax.numpy.linalg
7d81547f
Remove code that supported jaxlib < 0.5.
efab6945
Merge pull request #25962 from hawkinsp:oldcode
783d03c5
Rename hidden_mode -> hidden_axes and hidden_mode_ctx -> use_hidden_a…
12b59f8e
Set timeout for artifact building and "run tests" steps
12beb00b
Add workflow for testing nightly/release artifacts
9fb29766
internal: check integer overflow in lax.asarray
45a35204
Expose hidden_axes via jax namespace as public API. Also mention it a…
c7f8d17f
Remove CUDA rpaths from jaxlib build.
034e967e
add f-string to mosaic memory space error msg
41738427
Add a sharding rule for `reduce_precision_p` and properly thread eqn.…
36daf369
Remove memories flag now that JAX 0.5.0 has been released since it al…
5a068da6
Update XLA dependency to use revision
b0117366
Merge pull request #25976 from arvoelke:fix-memory-space-error
cc38d8c1
Merge pull request #25969 from jakevdp:fix-util
aed9c6f1
Speed up attention kernel by using exp2
1f595063
Update docs of callbacks
4f8699c8
Update XLA dependency to use revision
d415c80b
Merge branch 'jax-ml:main' into use_exp2
e7db4d50
[better_errors] Improvements in propagation of debugging info
dcf72b01
Merge pull request #25916 from gnecula:debug_info_4
ce48f647
Update XLA dependency to use revision
a43edb46
[Mosaic TPU] Add a faster implementation for packing b16 to s8 in TPUv6
543dd947
[better_errors] Finally remove api_util.debug_info.
4fd0bb05
[better_errors] Refactor debug info tests
e5d89e73
Add `reshard` API in experimental. Currently for sharding_in_types we…
799eb98c
Merge pull request #25988 from gnecula:debug_info_tests
e41f4caa
Merge pull request #25984 from Rifur13:use_exp2
7f19b345
Don't allow users to query `tracer.sharding` even under sharding in t…
d50d1e2c
Fix the backwards pass and support more block sizes.
b8b9f2bc
[sharding_in_types] Move the calculation of new_mesh inside the decor…
a943ebf4
Make sure reshard and mesh_cast behave properly under eager mode
bba5ada5
[Mosaic GPU] Do not use `mgpu` in `wgmma.py`
0a89760c
[Mosaic GPU] Add support for converting all fragmented layouts to ir …
f89accc5
[better_errors] Ensure debug_info.arg_names is never None.
3f73f7b0
[Pallas] Improve testing for casts from narrow types + test int4
3c8cf3c9
[mosaic_gpu] a function for bitwidth as well as bytewidht.
c4643c61
Make `make_mesh` take `visible_axes`, `hidden_axes` and `collective_a…
d30476a0
Update XLA dependency to use revision
aaa7e922
Merge pull request #25982 from roth-jakob:warning_callbacks
3b5b9816
[mosaic_gpu] Removed debug prints in `jax._src.lib.mosaic_gpu`
4c363766
Merge pull request #25647 from Rifur13:bwd_pass
70a5175d
Add part (non-quantized K/V pages) of paged_attention_kernel tests ba…
96a3ed36
Add an option to simplify `keystr` output and use a custom separator.
7f43316e
[Mosaic] Remove hardcoded TARGET_SHAPE and align Python/C++ APIs.
79bd72e2
Remove unnecessary use of xla_client.OpMetadata class.
afb750cf
Add support for constants in the decomposition of `lax.composite`s.
aa8c0010
Error out if contracting dimensions are sharded and ask the user to p…
051861bb
jax.numpy setops: use ensure_arraylike & avoid asarray
a69f9dcc
Remove `device_context` from `trace_context` because we don't need it…
3aa55992
Remove meaningless template keywords.
54bb7f5d
[Mosaic TPU] Emulate converting x16 vector to mask if mask packing is…
908df65a
Fix cache init when JAX Array is created early (#25768)
f87c94db
Merge pull request #25992 from gnecula:debug_info_arg_names
e304e9ea
[better_errors] Expand the tests for debug_info
849ccc97
[XLA:CPU] Use central difference to calculate numerical gradient
dc16721b
Merge pull request #26028 from gnecula:debug_info_more_tests
d6028153
Update XLA dependency to use revision
5c6b6c12
Integrate LLVM at llvm/llvm-project@d33e33fde770
6c76cc4e
Replace "gpu" with "cuda" to be specific about the type of gpu tests …
14029c67
[Mosaic GPU] Add manual consumed barrier handling to WS pipeline.
10bb38bb
Merge pull request #26024 from jakevdp:setops-args
fc935608
Set __slots__ on core.Trace subclasses.
f4adcc65
Add e8m0fnu support by conditional dtype.
638c6ae0
fix error tolerance
fe37afe9
Merge pull request #26044 from hawkinsp:slots
1b6080d9
Merge pull request #25889 from Stella-S-Yan:cache_reset
f6243ff8
DOC: mention adoption of EffVer in JEP
423be16e
Remove axis_name from unmapped_aval
23d360bd
Update debugging Pallas g3doc to remove text about scalar printing re…
64e9b07e
Merge pull request #26050 from jakevdp:effver-adoption
0fccabcd
Speed up name stack printing.
cd51e9dd
[sharding_in_types] Make `vmap` work with shard_map + pallas
704b2e5f
Add job that runs Bazel single accelerator and multi-accelerator CUDA…
9aad6a68
Support axis_index using a nested shard_map instead of iota with full…
f3e27b6c
[Mosaic GPU] Implement basic WGMMAFragLayout inference and propagation
3a411d88
[Mosaic GPU] Add a result to the WGMMA op definition in the MLIR dialect
6b747b41
[Mosaic GPU] Simplify enums in the MLIR Mosaic GPU dialect.
f57d603c
[Mosaic GPU] Remove an unnecessary restriction in the `vector.store` …
6f609926
Fix typo and improve readability of workflow documentation
e8d40ff1
Remove code that sets CUDA rpath option from .bazelrc.
3b772b61
Remove axis_name from pmap_unmapped_aval_handlers
33aa088a
Don't calculate tracing debug_info twice
6b95ad0a
[jax:custom_partitioning] Support SdyShardingRule with multiple leading
0c3de93b
internal: move more NumPy APIs to ensure_arraylike
23c1d629
Merge pull request #26060 from hawkinsp:cuda
4222c30c
#sdy unskip JAX Shardy tests that are already passing
db8c8fc3
Update XLA dependency to use revision
281ce646
Merge pull request #26063 from jakevdp:ensure-arraylike
f4313bb6
Add an experimental interface for customizing DCE behavior.
e3b3b913
Add DCE rules for custom_jvp and custom_vjp.
28d57335
[jax:custom_partitioning] Extend the custom partitioning API to accept a
e33cc428
Merge pull request #25116 from wenscarl:fp8_e8m0fnu
8442d64a
[JAX] Optimize array shard reordering
1d016962
Add jax.test_util to public API docs
458f6a6e
add docstrings for check_vjp and check_jvp
b2afb5bf
Move transfer python bindings into jax.
3864512b
[Mosaic TPU] Use vmask pack if possible for mask's bitwidth change an…
8e1f9568
Remove all MHLO uses, replace it with StableHLO
313e35a2
[export] Fix mis-used of NamedSharding in export tests
6dd12347
[Mosaic GPU] Add basic support for TMA with sub-byte types
7043b852
Don't pass dtype to lax_internal._zero
46d8cd2a
[mosaic_gpu] Fixed mosaic_gpu-serde pass registration
9ee7123c
Merge pull request #26076 from gnecula:fix_ps
20f02c49
Merge pull request #26046 from Cjkkkk:fix_cudnn_sdpa_dbias_error_tole…
1f23253b
Merge pull request #26072 from shoyer:test_util_doc
33ec6294
[Pallas:MGPU] Add helpers to make writing core_map kernels less verbose
c10b9b88
[Mosaic GPU] Add implementation of FA3 with pipeline emitter.
617e79f8
Adding GPU paged attention kernel
e0b38f4e
docs: outdated link for index update syntax
9e89ae7a
Update XLA dependency to use revision
9fb423ed
Merge pull request #26082 from Saransh-cpp:index-update-syntax-link
726abc9c
Merge pull request #25839 from Rifur13:paged
407c5b1c
Merge pull request #25963 from dfm:dce-custom-star
7a23d1d6
Don't computing forwarding information if we're going to inline.
cbc2d623
Use `backend._get_all_devices()` to validate devices.
08d81e45
Replace Hidden/Visible/Collective AxisTypes names with Auto/Explicit/…
d28c3fa4
Add new Bazel remote cache configs
89a9c6c2
Optimize the set_xla_metadata context manager.
184aefa4
improve make_array_from_single_device_arrays error
1fb4b93d
Merge pull request #26092 from mattjj:make-array-error
a8adf752
Update XLA dependency to use revision
55efd4b2
[better_errors] Add more debug info test coverage
e4d5427d
Merge pull request #26093 from gnecula:debug_info_tests1
381da3cf
Optimize implementation of the compute_on context manager.
77632791
Update XLA dependency to use revision
84921792
Optimize JaxprEqnContext context manager.
95cb0eb1
[better_errors] Refactor more uses of partial_eval.tracing_debug_info…
7361d173
Merge pull request #26097 from gnecula:debug_info_no_pe_debug_info
2a6accd6
[Mosaic GPU] Fix error message to make it clearer.
101f18d4
[Mosaic GPU] Handle the `swizzle` attribute in the lowering of `async…
a3a285dd
[better_errors] Refactor more uses of pe.tracing_debug_info (part 2)
878272ee
Merge pull request #26099 from gnecula:debug_info_no_pe_debug_info_2
9b5cb45b
[Mosaic GPU] Use a single instance of the `single_thread_predicate` i…
a0db6c5c
Rename debug_info_tests.py copied from api_test.py
c61401ab
Disable pytorch_interoperability_test under asan.
42fd586e
#sdy Enable more `shard_map` tests under Shardy.
21913b8e
github-actions[bot] enabled auto-merge 360 days ago
Merge branch 'rocm-main' into ci-upstream-sync-97_1
63e6442b
Use hipfft XLA fix
30aada2b
Skip PallasCallRemoteDMAInterpretTest.test_interpret_remote_dma_pperm…
41ab12bf
charleshofer requested a review from charleshofer 359 days ago
charleshofer approved these changes on 2025-01-27
github-actions merged a366d414 into rocm-main 359 days ago
charleshofer deleted the ci-upstream-sync-97_1 branch 359 days ago
Reviewers: charleshofer
Assignees: No one assigned
Labels: None yet
Milestone: No milestone