jax
CI: 03/18/25 upstream sync
#294
Merged
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Overview
Commits
41
Changes
View On
GitHub
CI: 03/18/25 upstream sync
#294
charleshofer
merged 41 commits into
rocm-main
from
ci-upstream-sync-151_1
Use `lax.top_k` instead of `jnp.argsort` in Gumbel top-k trick for we…
4a82fe94
add experimental lax.optimization_barrier autodiff rules
dadc68b6
Add a C++ implementation of a toplogical sort.
14cb7453
Merge pull request #27174 from mattjj:opt-barrier-ad-rules
7db59cdc
mixing modes
3c0027af
Merge pull request #27177 from jax-ml:mixing_modes
d07d642d
Support error checking in explicit mode
9b0ace4a
Update XLA dependency to use revision
f360e191
Better docs for jax.lax add/sub/mul/div
de8b0564
Change the way that batching.spec_types is updated.
466ef6a1
Update XLA dependency to use revision
e8b683ae
Merge pull request #27176 from jakevdp:lax-docs
761b35c5
[Mosaic GPU] Add support for fast WGMMA layout changes after 8- to 16…
2bdd9c87
[Mosaic GPU] Add support for changing the layout before the upcast
89b21de6
[pallas:mosaic_gpu] `jnp.reduce_sum` now works for >1D arrays
a7e5eaee
Update XLA dependency to use revision
55812c5d
Removed trivial docstrings from JAX tests
0ff23404
[Mosaic GPU] Make the s4 -> bf16 upcast more flexible when it comes t…
3649da56
Pin numpy~=2.1.0 in workflow file instead of test-requirements.txt
031614c2
Merge pull request #27157 from mar-muel:improve-random-choice-perform…
de9ad6ba
Add replace option to random.categorical to enable sampling without r…
3f59fa68
[Mosaic GPU] Add initial transform inference rules for `vector.{load,…
9a686e0b
Remove code that preserved _original_py_fns on C++ classes.
be5d13af
Merge pull request #26980 from carlosgmartin:categorical_replace
ebcae0d3
Replace cached function get_replicated_hlo_sharding() with a constant.
20658fab
Fix error in pallas tutorial
4f704713
Add B200 testing to continuous workflow
ecf7fde7
Merge pull request #27164 from MichaelHudgins:a4-testing
b74b16f9
Compute tile index using tile-based coordinates
b4966130
[pallas] `pallas_call_p` is now parameterized by a mesh
051687dc
Enable `jax.device_put` to a sharding with no local devices.
8c351917
Replace the uses of `PjRtClient::Compile()` with `PjRtClient::Compile…
f174b00f
Allow pspec to be passed to device_put if there is a mesh in the surr…
549973de
[Mosaic GPU] Remove sub-byte conversion restriction
34cd5b0d
[mosaic_gpu] Force flush all cupti activity, then unsubscribe.
38d52a19
[Mosaic GPU] Add a specialized layout for loading 4-bit inputs in WGM…
d4bd2570
[Mosaic GPU] Add transform inference rule for `mgpu.slice_smem`.
ba2f7c9a
Update XLA dependency to use revision
7a459f0e
[Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts
8da93249
[Mosaic GPU] Raise a `NotImplementedError` if `swizzle=16`.
1e36cbe5
github-actions
requested a review
1 year ago
github-actions[bot]
enabled auto-merge
1 year ago
disabled auto-merge
1 year ago
Manually disabled by user
Merge branch 'rocm-main' into ci-upstream-sync-151_1
c7b407c9
charleshofer
approved these changes on 2025-03-18
charleshofer
merged
c46b4fc0
into rocm-main
1 year ago
charleshofer
deleted the ci-upstream-sync-151_1 branch
1 year ago
Login to write a write a comment.
Login via GitHub
Reviewers
charleshofer
Assignees
No one assigned
Labels
None yet
Milestone
No milestone
Login to write a write a comment.
Login via GitHub