jax
CI: 12/09/24 upstream sync
#175
Closed
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Overview
Commits
136
Changes
View On
GitHub
CI: 12/09/24 upstream sync
#175
charleshofer
wants to merge 136 commits into
rocm-main
from
nighly-sync-09-12-2024
Use optimize='auto' for multi_dot.
236d4c60
Add exec_time_optimization_effort and memory_fitting_effort flags.
762301fc
Use next to tiny as smallest floating point value on Mac ARM
504c7387
Add version check for effort flags
83b54d97
Merge branch 'main' into add-optimization-effort-flags
c65ce4b0
Fix jnp.matmul return shape documentation
cd578d97
Save residuals in the decode attention pallas kernel
a4e742d2
[Pallas:MGPU] Make the shapes from the attention example more interes…
8a316195
Update the docstring of jax.lax.switch
bbc4a20c
[Pallas:MGPU] Fix a use-after-free in lowering
b1423a36
Merge pull request #25191 from houeland:patch-1
2e0474a5
jnp.reshape: raise TypeError when specifying newshape
a7039a27
Merge pull request #25117 from pearu:pearu/arcsin-mac-arm
6b029503
[Mosaic GPU] Automatically squash a >3D logical grid into a 3D physic…
784ebeab
Merge pull request #25179 from rajasekharporeddy:lax_switch
385328b5
Skip vecmat & matvec in NumPy tests.
f182aa8e
Merge pull request #25055 from dfm:multi-dot
46c748b9
[jax] Typing on common_devices_indices_map
c9a59022
Update XLA dependency to use revision
9f203017
Update Cloud TPU workflow with new build.py usage
0134fa83
[shape_poly] Remove obsolete part of the shape polymorphism documenta…
b3c405c2
Merge pull request #25216 from gnecula:poly_doc
908865f2
Add an option to deactivate automatic cluster detection in jax.distri…
6a8bbcba
Merge pull request #24964 from emilyfertig:emilyaf-deactivate-cluster…
c9c043cf
Merge pull request #25210 from jakevdp:fix-nightly
f2f02eea
[Mosaic GPU] Improve default kernel name and add option to customize
7bd81dbe
Fix some typos
f43fa9fc
[Pallas:MGPU] Add tests for attention with non-trivial batch size
0bb68f6a
[jax] Improve naming of `DotAlgorithmPreset` properties and simplify …
abf8f430
[jax] Make `DotAlgorithmPreset.supported_output_types` a function of …
a54319ec
Fix doc typo
4b6035ca
Fix nightly numpy test
2afc65a1
Fix missing quotes in local xla path
cc95327a
Merge pull request #25231 from jakevdp:fix-nightly-names
8c66cba4
Merge pull request #25162 from nireekshak:testbranch
dfa0dd70
Integrate Triton up to [9732c047](https://github.com/openai/triton/co…
c4d19ca8
Reverts a54319ec1886ed920d50cacf10e147a743888464
73962b74
[Pallas TPU] Enable test for `jnp.logical_not` because it's now suppo…
2dae81a8
Merge pull request #24748 from jakevdp:reshape-dep
d990dcf2
[Mosaic TPU] Support packed type matmul with arbitrary shapes.
9e5edb70
Update XLA dependency to use revision
ceeed909
Improve trace-time performance of jnp.isscalar
0140a98e
Add _raw_platform to work around extra platform normalization logic a…
fcf0b6d3
Use JAX's default device instead of jax.devices()[0], if set.
fd4b1608
Merge pull request #25237 from jakevdp:faster-isscalar
40122f7c
Fix indexing corner case with empty ellipses
f6f4ef06
Convert MSYS' Linux-like paths to Windows paths in JAX CI.
8c78c1e7
[Mosaic GPU] Add missing import.
cb2cf56e
[mgpu_pallas] Optionally pass default value instead of raising an err…
1ddba9b1
[mgpu_pallas] Allow loading scalars or indexing arrays from gmem usin…
3895e037
[Mosaic GPU] Add an optimization barrier
11090be0
Fix Windows portability problem in compilation cache test.
5a250097
Merge pull request #25246 from hawkinsp:win
09177cf7
Disable backwards compatibility test for Triton IR.
2ac26924
[pallas:mosaic_gpu] `emit_pipeline` no longer ignores transforms
12b45b32
[pallas:mosaic_gpu] Use `jax.tree_util.register_dataclass` for transf…
46eb77be
Disable JaxAotTest.test_topology_pjit_serialize on GPU, which fails i…
bdadc53e
Disable pgle_test on non-GPU platforms.
681b9c2e
Fix the broken behavior of not resetting the abstract_mesh and device…
653f6545
[Pallas] Update changelog for `pl.estimate_cost`
721b517e
Merge pull request #25239 from jakevdp:indexing
1da03791
Merge pull request #25006 from andportnoy:aportnoy/mosaic-gpu-kernel-…
fa6585de
CI: update array-api-tests to latest commit
8563449a
[sharding_in_types] Use `set_mesh` API to trigger sharding_in_types i…
9e2708eb
Merge pull request #25264 from jakevdp:update-array-api
222b2e75
Merge pull request #25199 from Rifur13:save_residuals
db97d7aa
[Pallas] Fix type annotation on TritonCompilerParams
1a3c9c44
Simply abstract_mesh and device_context context managers and handle e…
a735bf83
Remove obsolete deprecation
5ade371c
[Pallas] Pallas documentation cleanup
e05afefc
[Mosaic] Add extra memref_slice verification and a memory space check…
3990e05a
context manager methods for AbstractMesh to appease type checker.
208194f9
Merge pull request #25114 from jedborovik:add-optimization-effort-flags
182e5326
Update XLA dependency to use revision
28528d44
More thorough propagation of host linear layout. Currently linear lay…
f160df04
[Mosaic:TPU] Lift offset restrictions on single-row (1, 128) -> (8, 1…
10116874
remove vestigial ad.reducing_transposes table
6172a1f1
[Mosaic:TPU] Add relayout for adding minor implicit dim and relax som…
8163e74e
[AutoPGLE] Add multi-process test case
7214a3a8
[shape_poly] Fix the handling of __pow__ for symbolic dimensions
4e17bea9
[Mosaic GPU] Remove expect_wait from Barrier.wait
c965ffbf
[pallas:mosaic_gpu] Removed leftover debugging code
03861d43
Merge pull request #25276 from mattjj:remove-vestigial-reducing-trans…
39d73a68
Reverts 73962b740890a728295fa09f515dcf96cb820822
569c2a3c
[Mosaic TPU] Add support for modeling loads/stores and fix minor issu…
d5ead570
[Mosaic GPU] Always annotate block initialization in the profiles
d034680f
[pallas:mosaic_gpu] Do not store the grid mapping in `ModuleContext`
e5102957
[shape_poly] Remove some deprecated kwargs
5fe5206b
[pallas:mosaic_gpu] Removed unnecessarily strict check in `emit_pipel…
4a41aa0a
[export] Removed __gpu$xla.gpu.triton (Pallas GPU) from the list of c…
3f5f3e1c
jax.numpy: require boolean dtype for where argument
29a8cce6
jnp.linalg.vector_norm: properly support multiple axes
aaaee63a
[JAX] Add end-to-end execution support in colocated Python API
e20a483b
Merge pull request #25271 from jakevdp:fix-vector-norm
a71f9a62
Merge pull request #25290 from jakevdp:reduction-where
f73fa7a7
[jax:custom_partitioning] Implement SdyShardingRule to support
2a4a0e8d
[Mosaic:TPU] Fix fully replicated relayout
23d5c10f
JAX release 0.4.36.
7e6620a5
[Pallas] Fix shard_axis in dma_start interpret mode rule.
259194a6
[Pallas] Fix shard_axis in dma_start interpret mode rule.
fd42b561
Merge pull request #25269 from justinjfu:pallas_docs_cleanup
d782b246
[Mosaic:TPU] Fix elementwise inference with i1s
651ab188
[pallas] fix jumble test flakiness
84f3f992
Merge branch 'release/0.4.36' after release
ab02bf87
Update XLA dependency to use revision
1ca8903a
Merge pull request #25296 from hawkinsp:postrelease
45159a2f
Merge pull request #25252 from gnecula:poly_power
9fc077a5
Bump JAX version after release.
ba626fa6
Activate Schur Decomposition to XLA's FFI
9081e85d
Set -Werror=mismatched-tags on clang.
fac1b1a7
[pallas:mosaic_gpu] `FragmentedArray.reduce_sum` now returns a `Fragm…
bae66000
[mosaic_gpu] Emit a slightly more informative error message in `Fragm…
08d31d0f
Fix error when swapping a ref with a trivial indexing transform.
af501356
[Pallas MGPU] Use multiple k/v_consumed_barriers in the attention kernel
8b656206
[Pallas MGPU] Disable XLA:GPU autotuning in attention tests
eda7506d
Don't look for CUDA files when building the ROCm wheel
0c6b967e
Merge pull request #25205 from jburnim:jburnim_swap_fix
72df8e0c
[Pallas] Add support for run_state to cost estimator.
641a1d53
Document cudaMallocAsync as an experimental feature.
a13b618c
Merge pull request #25082 from nouiz:doc_cuda_malloc_async
b6499e21
Add a flag to enable detailed timestamped logging of subprocess comma…
83c64b23
Update XLA dependency to use revision
baedb62b
Support transfer guard in broadcast_one_to_all(). Fixes https://githu…
861115ad
Temporarily allow bfloat16 dot algorithms on CPU.
1f4d184a
Update XLA dependency to use revision
ad00ee1e
Update XLA dependency to use revision
70623255
Fix type annotation for numpy.linalg.matrix_norm argument 'ord'.
efa35ea9
[pallas:triton] Add support for `DotAlgorithmPreset` `precision` argu…
3ec55c77
[pallas] Add `DotAlgorithmPreset` note to CHANGELOG.
a94474d0
[Mosaic TPU] Allow downgrading the IR during serialization for forwar…
adb2bf62
Activate Tridiagonal Reduction to XLA's FFI
d474feda
Merge pull request #25338 from carlosgmartin:fix_numpy_linalg_matrix_…
5a1c4c57
Ensured that JAX type checks under pytype on Python 3.12
1ac6b762
Merge pull request #25320 from ROCm:gh-9948-fix-kernel-build-upstream
cc258f5f
Remove dead code after minimum jaxlib version bump to v0.4.36.
79318a08
Merge branch 'rocm-main' into nighly-sync-09-12-2024
98958d01
charleshofer
closed this
1 year ago
charleshofer
deleted the nighly-sync-09-12-2024 branch
1 year ago
Login to write a write a comment.
Login via GitHub
Reviewers
No reviews
Assignees
No one assigned
Labels
None yet
Milestone
No milestone
Login to write a write a comment.
Login via GitHub