jax
CI: 11/22/24 upstream sync
#148
Merged
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Overview
Commits
155
Changes
View On
GitHub
CI: 11/22/24 upstream sync
#148
charleshofer
merged 155 commits into
rocm-main
from
ci-upstream-sync-34_1
Add a link to Intel plugin for JAX
68428488
Add float8_e4m3 and float8_e3m4 types support
78da9fa4
Move Control Flow text from Sharp Bits into its own tutorial.
e6f6a8af
jnp.logaddexp2: simplify implementation
d823f172
Update array-api-tests commit
d0f36666
Merge pull request #24864 from jakevdp:logaddexp2
303b792a
Add numpy.put_along_axis.
1f114b1c
Merge pull request #24904 from jakevdp:array-api
04d339d6
cleanup: delete unused argument from internal reduction helper
4a3e1155
Update XLA dependency to use revision
c40d405e
Add shard map replication rule for ffi_call.
41a0493e
Merge pull request #24567 from Intel-tensorflow:minigoel/intel-plugin
8e292122
Merge pull request #24905 from jakevdp:old-arg
5764afb4
Merge pull request #24871 from carlosgmartin:numpy_put_along_axis
4fe91645
Update array-api-tests commit
a115b2ce
Merge pull request #24881 from dfm:ffi-call-rep-rule
c6051b3e
Merge pull request #24907 from jakevdp:array-api
1c31860f
Merge pull request #24862 from emilyfertig:emilyaf-control-flow-tutorial
4511f0c6
[sharding_in_types] Handle collective axes in lowering rules more gen…
9a0e9e55
Set __module__ attribute for objects in jax.numpy
f652b6ad
[Mosaic TPU] Support 1D concat: set implicit_dim to kSecondMinor to t…
1471702a
Lower threefry as an out-of-line MLIR function on TPU.
23e9142d
Merge pull request #24913 from hawkinsp:threefry
d8085008
Add missing functions to jax.numpy type interface
5f942844
Add an example on logical operators to the tutorial.
5f1e3f56
Update XLA dependency to use revision
1780ff29
DOC: update main landing page style
81cdc882
Merge pull request #24918 from emilyfertig:emilyaf-logical-op-example
605c6051
Merge pull request #24914 from jakevdp:fix-pyi
1aa5de66
Consolidate material on PRNGs and add a short summary to Key Concepts.
225a2a5f
Merge pull request #24917 from emilyfertig:emilyaf-sharp-bits
efd23276
[sharding_in_types] Don't emit a wsc under full manual mode to avoid …
8525ef2b
Adds a flag to control proxy env checking.
609dfac2
Deduplicate constants in StableHLO lowering.
626aea01
Return a ndarray in shape_as_value if the shape is known to be constant.
1d519f4c
Update XLA dependency to use revision
7b9914d7
Use a direct StableHLO lowering for pow.
8a6c560b
Adds an env that can let users provide a custom version suffix for ja…
27bf80a5
Update XLA dependency to use revision
742cabc5
[AutoPGLE] Temporary disable pgle_test in the OSS.
ed250b89
Merge pull request #24930 from hawkinsp:dedup
f7ae0f99
Merge pull request #24933 from hawkinsp:pow
afdc7927
Add a GPU implementation of `lax.linalg.eig`.
ccb33170
Merge pull request #24932 from hawkinsp:gather
65f9c785
Add new CI script for running Bazel GPU presubmits
14187399
Merge pull request #24912 from jakevdp:jnp-module
05d66d7c
Merge pull request #24916 from jakevdp:update-lp
2de40e7d
Make logaddexp and logaddexp2 into ufuncs
e9864c69
Merge pull request #24903 from jakevdp:logsumexp
297a4e5e
Return SingleDeviceSharding instead of GSPMDShardings when there is o…
6fe7b171
fix typo in numpy/__init__.pyi
5bebd0f6
Merge pull request #24952 from jakevdp:fix-pyi
70b05f6c
[SDY] fix JAX layouts tests for Shardy.
0ed6eaeb
Disable some complex function accuracy tests that fail on Mac ARM.
461a2507
Filter custom dtypes by supported_dtypes in `_LazyDtypes`.
f3250516
[Pallas] Increase test coverage of pl.dot.
a60ef6e9
Merge pull request #24957 from hawkinsp:arm
16ed283f
Update XLA dependency to use revision
b3ca6c47
Merge pull request #23585 from apivovarov:float8_e4m3
91891cb6
Adds font fallbacks
d4316b57
Merge pull request #24958 from barnesjoseph:add-font-fallback
6952ddf4
Delete _normalized_spec from NamedSharding
e904c177
Fix a bug where mesh checking was not correct
2c68569a
[pallas] Minor simplifications to Pallas interpreter.
45c9c0a5
Update jax.scipy.special.gamma and gammasgn to return NaN for negativ…
c5e8ae80
Merge pull request #24945 from hawkinsp:gamma
4a9346e4
Merge pull request #24861 from yliu120:add_versions
58103e5a
[Mosaic TPU] Support relayout for mask vector
0fe77bc9
Merge pull request #24853 from yliu120:check_proxy_envs
12a43f1f
Implement lax.pad in Pallas.
d397dd96
[AutoPGLE] Use compile options to override debug options instead of X…
da50ad7e
Fix some typos
1458d3dd
fix(docs): typos in macro name
d912034c
Merge pull request #24942 from jeertmans:patch-1
9d3eda17
Merge pull request #24968 from nireekshak:testingbranch
6929a97c
Add missing version guard in GPU tests for jnp.poly.
3556a833
[Mosaic TPU] Add general tpu.vector_store and support masked store.
6c31efa3
Add alternate implementation of threefry as a pallas kernel.
c44f11d1
Add test utility for accessing jaxlib version tuple.
a59bbb7c
Add a new API jax.lax.split.
2c80d1af
Merge pull request #24970 from hawkinsp:split
2075b091
[pallas:mosaic_gpu] `copy_gmem_to_smem` no longer requires `barrier` …
1bf70fbb
[Mosaic] Add target core type parameter to tpu.sem_signal
0d36b0b4
Update XLA dependency to use revision
3161a284
Move JAX example to public XLA:CPU API
42fbd301
Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a
525b646c
[Mosaic] Extend tpu.sem_signal with subcore_id
c04aec9d
Make deprecated jax.experimental.array_api module visibility internal…
8c71d1ad
Fix a bug where constant deduplication used an inappropriate inequality.
867a3618
[Mosaic] Add `tpu.log` verification on SC
6c291d67
represent `random.key_impl` of builtin RNGs by canonical string name
4bb81075
Add test_compute_on_host_shared_sharding in memories_test
4d60db17
Merge pull request #24593 from froystig:random-dtypes
ae46b756
[mosaic_gpu] Fix signedness handling in FragmentedArray._pointwise.
1afb05e2
[pallas:mosaic_gpu] Add Pallas Mosaic GPU lowering for jax.lax.bitcas…
14da7ebb
[pallas:mosaic_gpu] `copy_smem_to_gmem` now supports `wait_read_only`
c76e5fe9
[mosaic_gpu] Fixed `FragmentedArray` comparisons with literals
f442d40f
[mosaic_gpu] Handle older `jaxlib`s in the profiler module
04e4c69f
[pallas] Do not skip vmap tests on GPU when x64 is enabled
1df4b5f7
Update XLA dependency to use revision
a582df02
Mention python 3.13 in docs & package metadata
a4266b5e
Simplify handling of `DotAlgorithmPreset` output types.
1e9e85a3
Deprecate several private APIs in jax.lib
85e2969a
Merge pull request #25007 from jakevdp:deps
800add2a
Merge pull request #25005 from jakevdp:py313
439d34da
Fix KeyError recently introduced in cloud_tpu_init.py
62225926
[pallas mgpu] Lowering for while loops as long as they are secretly f…
8d84f283
Make a direct linearize trace.
d0f17c0c
Merge pull request #25004 from jax-ml:linearize-trace
eab9026c
Remove internal KeyArray alias
fee272e5
Don't psum over auto mesh dims in _unmentioned2.
2c9b917b
[pallas:mosaic_gpu] Avoid using multiple indexers in the parallel gri…
9584ee3b
Set __module__ attribute of jax.numpy.linalg APIs
621e39de
Merge pull request #25008 from skye:barrier
1a3e693a
Mention stackless in the release notes.
dfe27a16
Merge pull request #25013 from hawkinsp:relnotes
19b4996e
[sharding_in_types] Make flash_attention forward pass in TPU pallas w…
40fc6598
[Pallas TPU] Support masked store
9d2f62f8
Merge pull request #25011 from jakevdp:jnp-linalg-module
d219439d
[sharding_in_types] Add slice_p and squeeze_p sharding rule to make f…
9b941808
Update XLA dependency to use revision
6fe78042
[array api] use most recent version of array_api_tests
f749fca7
DOC: add examples for jax.lax.pad
2699e950
Merge pull request #25019 from jakevdp:lax-pad-doc
334bd4d0
jax.lax.pad: improve input validation
17825882
Adds Google Sans font
bf7f9aa8
Merge pull request #25020 from jakevdp:lax-pad-validation
f39392ea
[Pallas] Use Pallas cost estimator for flash attention.
1f6152d1
[sharding_in_types] Add `pad_p` support to sharding_in_types to handl…
840cf3f7
[Mosaic TPU] Add bound check for general vector store op.
869a5334
[sharding_in_types] Add `concatenate_p` support
6568713a
Reverts c04aec9d525dd2e767495e41b98e82dd79315f37
e72b4490
[pallas:mosaic_gpu] Pulled `delay_release` into `emit_pipeline`
f18df8f3
Integrate LLVM at llvm/llvm-project@33fcd6acc755
1bc9df42
[shape_poly] Adding shape polymorphism support for the state primitives.
0831e2e3
Run the TPU workflow on new self-hosted runners
7d7a0fa2
[JAX] Ignore xla_gpu_experimental_autotune_cache_mode when calculatin…
bf0150bb
Merge pull request #25015 from barnesjoseph:add-google-sans
73352677
Fix cron schedule to run past minute 0 every 2nd hour
1e6654a0
Merge pull request #25018 from jakevdp:update-array-api
b1b1ad62
[mgpu] Pointwise op can handle LHS splats.
1d2dc17e
[pallas] Add more test cases for Triton bitcast_convert_type lowering…
2178ed2f
Merge pull request #25034 from gnecula:poly_state
e707edea
Fix false positive `debug_nans` error caused by NaNs that are properl…
96c01299
[pallas:mosaic_gpu] `emit_pipeline` now correctly supports `BlockSpec…
1efef6bf
Merge pull request #25009 from jakevdp:keyarray
c1ae13b0
Remove unneeded dependency from rocm_plugin_extension.
f3e7e682
[Mosaic TPU] Fold sublane offset to indices when storing to untiled ref.
f899d515
Update XLA dependency to use revision
26443bbd
[Pallas] Add readme page for debugging tips.
344d0d99
Change signature of linearization rules.
170718c8
Merge pull request #25048 from jax-ml:linearization-rule-signature
3d79df24
[sharding_in_types] Add scan support to sharding_in_types. There are …
355589f3
charleshofer
requested a review
from
charleshofer
1 year ago
charleshofer
approved these changes on 2024-11-22
Merge branch 'rocm-main' into ci-upstream-sync-34_1
a450bb07
Longer timeout for doc render
846697f7
charleshofer
merged
3be7c1e6
into rocm-main
1 year ago
charleshofer
deleted the ci-upstream-sync-34_1 branch
1 year ago
Login to write a write a comment.
Login via GitHub
Reviewers
charleshofer
Assignees
No one assigned
Labels
None yet
Milestone
No milestone
Login to write a write a comment.
Login via GitHub