jax
Rocm refactor pr
#19
Closed

Rocm refactor pr #19

reza-amd wants to merge 156 commits into main from rocm_refactor_pr
reza-amd
jpuigcerver Add constant initializer
86e8928e
rsepassi Update compilation cache logging to log individual hashes as well as …
085e6586
rsepassi Merge branch 'main' into compilelog
be12038b
LenaMartens Rename wrapper functions to always refer to the JAX api function.
010eb82a
mattjj improve pjit in/out_axis_resources pytree errors
d57990ec
jakevdp tests: use lax.broadcast_shapes in place of custom logic
3e504895
Merge pull request #9526 from jakevdp:fix-util
fd354071
hawkinsp Recommend optax in jax.experimental_libraries.optimizers documentation.
b9b73ee6
Merge pull request #9469 from hawkinsp:doc
2ae10ea7
jakevdp Respect __jax_array__ in jnp.ndarray operations
c069bfee
mattjj [remove units] make JaxprTrace.process_call not introduce units
7077ce2e
Merge pull request #9498 from mattjj:remove-units-2
06c40122
mattjj add flag for jaxpr colorful syntax highlighting
004bb684
Merge pull request #9456 from mattjj:jaxpr-pprint-color-flag-and-default
0566ea4c
mattjj [remove units] make JaxprTrace.process_call not introduce units
d59af33c
Merge pull request #9500 from mattjj:remove-units-3
51c7d3bb
RuffaloLavoisier shape_poly_test.py : remove duplicate word
fc4f47cd
Merge pull request #9539 from jakevdp:fix-jax-array
7af443aa
JaxTestCase now sets jax_numpy_rank_promotion='raise' by default
4f6004a3
Merge pull request #9491 from LenaMartens:changelist/427247461
fb821e94
LenaMartens Checkify: handle named_call in process_call.
a4cacf57
jblespiau Remove e Type annotation for jit an pmap as there are additional attr…
799ecfa9
Merge pull request #9558 from jblespiau:changelist/427986940
2c01312d
hawkinsp Add constant handler for tokens.
5a259925
jakevdp JaxTestCase: set jax_numpy_rank_promotion='raise' by default
97512e9e
Merge pull request #9557 from LenaMartens:changelist/428497052
8b117c50
Merge pull request #9559 from hawkinsp:token
fb4934c2
hawkinsp Fix incorrect binary search comparison in lax.select_n lowering.
29c8a045
Merge pull request #9563 from hawkinsp:selectn
a457320b
pschuh Add translation rule for optimization barrier.
7ce911b8
Merge pull request #9562 from jakevdp:disable-rank-promotion
f229a703
MichaelMarien Add a warning to random.choice to notify users of the ill-defined beh…
20e5090b
hyeontaek Implement the JAX transfer guard API
beaa00c4
Merge pull request #9564 from MichaelMarien:random-choice-docstring
d5694402
jakevdp Add type promotion design doc
7381bbe8
Merge pull request #9407 from jakevdp:type-promotion-design
7204ac30
tlu7 [sparse] Updates `bcoo_dot_general` cuSparse lowering rule by adding …
273ea626
Merge pull request #9485 from rsepassi:compilelog
924f7aab
yashk2810 Add multi-host utilities to JAX core. Adapted from https://github.com…
7613d2a5
LenaMartens Run all tests with jax_traceback_filtering=off.
0d9990e4
apaszke Add a partial_eval_jaxpr_custom_rule for xmap
c551beda
apaszke Limit the set of unspecified dims to those that are not explicitly co…
b75b0c04
hawkinsp Update the list of default CUDA capabilities used for wheel builds to…
2e0cfe8e
Merge pull request #9551 from RuffaloLavoisier:typo
0b920521
apaszke Try adding support for nesting with_sharding_constraint in MANUAL xmaps
a14ccb99
hawkinsp [JAX] Fix crash when applying jit() to a callable that is not weak-re…
b0b8f037
Merge pull request #9573 from hawkinsp:cuda
bf3c6581
jakevdp type promotion design doc: minor typos
de360833
LenaMartens Checkify: fix check_error of nd-error.
b15c7f60
jakevdp Fix nomenclature in type promotion doc
e82f232e
Merge pull request #9587 from jakevdp:fix-type-promotion
1da8b502
pschuh Merge branch 'main' into opt-barrier
662c4416
Merge pull request #9585 from jakevdp:typos
95c486a5
Merge pull request #9561 from pschuh:opt-barrier
c49fb9c2
oliverdutton fix: use boolean tf functions where possible instead of casting to in…
c1cbf786
oliverdutton fix: missing comma
88a1c049
gnecula [jax2tf] Fixed stale documentation about XLA metadata.
461b37b2
gnecula [jax2tf] Fixes shape polymorphism for jnp.take_along_axes
1928f6e6
Merge pull request #9595 from gnecula:tf_metadata
35082fce
Merge pull request #9597 from gnecula:tf_take_along_axis
c97fefcc
yashk2810 Set the aval inside _create_local_shards iteration. Since we are iter…
94aade03
LenaMartens Checkify: fix nd-error case when array only has 1 element.
758c7216
Merge pull request #9493 from mattjj:better-pjit-pytree-prefix-error
e1fd6304
Merge pull request #9602 from LenaMartens:changelist/429094321
052b5c36
hawkinsp Add cloudpickle as a test requirement.
901d459e
Merge pull request #9605 from hawkinsp:pickle
e25259e5
SaturdayGenfo adds jax.scipy.schur
514d8883
SaturdayGenfo adds jax.scipy.linalg.sqrtm
cb732323
yashk2810 Merge `mesh` and `Mesh`. Make `Mesh` a context manager + class so tha…
a83695a7
froystig implement `jnp.expand_dims` and `jnp.stack` for PRNGKeyArrays
0f7904f8
yotarok Add some functions for spectral analysis.
e085370e
mattjj use singleton dims in broadcasting binop batchers
e0fb424d
yashk2810 `del self._old_env` so that you can use `with global_mesh` multiple t…
bd2a6a07
oliverdutton fix: linting errors
b657c23b
Merge pull request #9544 from SaturdayGenfo:adds-matrix-sqrt
15295a82
apaszke Fix uninitialized axis_env error when MLIR lowering is disabled
57f42320
LenaMartens Checkify: explicitly export public API, hide private symbols.
73f23705
oliverdutton refactor: rename, reorder, rewrite
6ccbc862
Merge pull request #9609 from froystig:prng-array-stack
032bfe09
froystig err on repeated axes to `expand_dims`, as numpy does
35fab1a9
hawkinsp Clarify the NVidia driver version requirements.
d704c151
froystig in tests, compare jnp operations on PRNGKeyArrays to the same on jnp …
88c6b84d
Merge pull request #9616 from hawkinsp:doc
83a50202
yashk2810 Xmap GDA integration. Non-contiguous mesh is allowed!
6bb58e6f
Merge pull request #9615 from froystig:jnp-expand-dims-error
5bb140f6
Merge pull request #9422 from yotarok:signal_stft
54a6e4da
hawkinsp Add jax.distributed and jax.dlpack to the docs.
3e5ecfe3
Merge pull request #9621 from hawkinsp:docs3
e545daa1
rsuderman Fix Iree backend to for copy_to_device and executable results
50f4b580
jakevdp Add deprecation warning to JaxTestCase and JaxTestLoader
da3aaa19
Merge pull request #9620 from jakevdp:deprecate-jax-test-case
d22129f1
yashk2810 Show all the device buffer shapes in the error message.
a430d0f2
yashk2810 Fix a typo (remove `== ss`).
b78ec009
oliverdutton fix: simplify reduce_min with reduce_all
20635e4a
gnecula [jax2tf] Added more links to the documentation
35763889
oliverdutton docs: spelling mistake fixed
f926d089
oliverdutton docs: fix explanation
e6d94e11
Merge pull request #9631 from gnecula:tf_doc
d123a10c
jblespiau Turn execute_replicated into a class so we can access its fields.
607e7033
Merge pull request #9594 from oliverdutton:boolean_jax2tf
1baa59c6
jakevdp Index update operators: add scatter_apply()
e13c847e
reza-amd
ukoxyz Allow unevenly partitioned sharding_constraints.
8cb16928
yashk2810 * Make _old_env thread local so that it can be used in multiple threads.
1486be7b
LenaMartens Checkify: address initial feedback.
2eeb683d
yashk2810 Revert back to adding aval on Device buffers inside local_shards and …
d1b6f5d9
reza-amd
yashk2810 Finish jax release
c161c628
yashk2810 Make resharding of GDA work if the shape is larger than what it was s…
3290dd3a
rsuderman IREE's get_default_device_assignment should return List[Device]
23d5eb08
Merge pull request #9607 from jakevdp:scatter-apply
660616cf
Merge pull request #9617 from froystig:prng-array-tests
a65841f5
tornikeo Add copy button to code snippets in documentation
6423741d
apaszke Always treat all mesh axes controlled by xmap as MANUAL
2641f061
yashk2810 Make local_to_global and global_to_local private.
afdb7b6f
Merge pull request #9636 from LenaMartens:changelist/429277776
97b1bd3b
Merge pull request #9613 from mattjj:vmap-binop-batchers-use-singletons
ab15db7d
mattjj fixing scan and other control flow
4b1d0a46
LenaMartens Fix tests and handle cond consts.
45d3ddda
Merge pull request #9663 from mattjj:checkify-scan-debug
7edd964c
yashk2810 Make the name `pjit` appear in xprof and mhlo module name. Before eve…
e96b91d4
sharadmv Add separate mechanism for threading name stacks to the lowering
1b79caa6
yashk2810 Deprecate `maps.mesh` and replace it with `maps.Mesh`.
687a7630
Merge pull request #8395 from sharadmv:name-stack-mechanism
c0416945
yashk2810 Fix the gpu tests that were failing with Future warning
e2834d89
yashk2810 Rename `unmapped_local_out_avals` to `out_avals` since it can contain…
98e114da
yashk2810 Add .device() method to _DeviceArray
a9a827e2
Merge pull request #9052 from jpuigcerver:main
3948fde8
hawkinsp Remove jax.ops.index... functions.
f51a05a8
[JAX] Update comments and documents for ANN.
f2c17439
Merge pull request #9619 from hawkinsp:index2
cdd0c688
[JAX] Move ann.ann_recall back to tests.
8372b98c
Merge pull request #9658 from tornikeo:docs-copybutton
d2834b64
jakevdp Remove duplicate changelog entry
51727033
Merge pull request #9684 from jakevdp:fix-changelog
ffcf4773
Merge pull request #9641 from rsuderman:FixDefaultDevice
6914bae3
Merge pull request #9623 from rsuderman:FixIreeJax
d5a1c64d
jakevdp DOC: update sharp bits info about fori_loop differentiability
74e7fdfd
jakevdp BUG: return numpy arrays for jnp.load() with unsupported dtypes
1b01865b
Merge pull request #9702 from jakevdp:fix-load
7ec04c83
samuela Relax tolerances slightly for MKL.
bf59b7d8
Merge pull request #9706 from samuela:samuela/mkl-fix
d583f876
froystig make `xla_executable` a property, consistent across executable types
d636e746
Merge pull request #9193 from froystig:xla-executable-attr
c66fbeb9
jblespiau Remove an unnecessary condition.
25472c23
Merge pull request #9701 from jakevdp:fix-sharp-bits
0b9ae982
Merge pull request #9711 from jblespiau:changelist/431385993
043561ae
dependabot[bot] Bump actions/setup-python from 2 to 3
680c06dd
hawkinsp [XLA:CPU] Relax test tolerances for tests using XLA:CPU.
c339330b
reza-amd
Merge pull request #9718 from google:dependabot/github_actions/action…
92cb865b
jakevdp jax.random.poisson: fix corner cases
2c2773a5
reza-amd reza-amd force pushed from 18420dcc to 806163bc 3 years ago
yashk2810 Add `block_until_ready` method to GDA
d0cc3395
Merge pull request #9721 from jakevdp:poisson-nan
c7508d1f
hawkinsp Handle jaxpr constants correctly in MLIR lowering of conditional bran…
cffe9978
jakevdp Factor-out pieces of lax_numpy.py
d09d7b8d
reza-amd
Merge pull request #9724 from jakevdp:refactor-lax-numpy
5b309880
yashk2810 Make the global_mesh that GDA has public
11664f8a
reza-amd Update JAX to use new math libraries in ROCm-5.0.
a0d9d81f
reza-amd reza-amd closed this 3 years ago
reza-amd reza-amd reopened this 3 years ago
reza-amd reza-amd force pushed from 51157ee7 to a0d9d81f 3 years ago
reza-amd reza-amd closed this 3 years ago
gulsumgudukbay gulsumgudukbay deleted the rocm_refactor_pr branch 214 days ago

Login to write a write a comment.

Login via GitHub

Reviewers
No reviews
Assignees
No one assigned
Labels
Milestone