jax
Rocm refactor pr
#19
Closed
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Overview
Commits
156
Changes
View On
GitHub
Rocm refactor pr
#19
reza-amd
wants to merge 156 commits into
main
from
rocm_refactor_pr
Add constant initializer
86e8928e
Update compilation cache logging to log individual hashes as well as …
085e6586
Merge branch 'main' into compilelog
be12038b
Rename wrapper functions to always refer to the JAX api function.
010eb82a
improve pjit in/out_axis_resources pytree errors
d57990ec
tests: use lax.broadcast_shapes in place of custom logic
3e504895
Merge pull request #9526 from jakevdp:fix-util
fd354071
Recommend optax in jax.experimental_libraries.optimizers documentation.
b9b73ee6
Merge pull request #9469 from hawkinsp:doc
2ae10ea7
Respect __jax_array__ in jnp.ndarray operations
c069bfee
[remove units] make JaxprTrace.process_call not introduce units
7077ce2e
Merge pull request #9498 from mattjj:remove-units-2
06c40122
add flag for jaxpr colorful syntax highlighting
004bb684
Merge pull request #9456 from mattjj:jaxpr-pprint-color-flag-and-default
0566ea4c
[remove units] make JaxprTrace.process_call not introduce units
d59af33c
Merge pull request #9500 from mattjj:remove-units-3
51c7d3bb
shape_poly_test.py : remove duplicate word
fc4f47cd
Merge pull request #9539 from jakevdp:fix-jax-array
7af443aa
JaxTestCase now sets jax_numpy_rank_promotion='raise' by default
4f6004a3
Merge pull request #9491 from LenaMartens:changelist/427247461
fb821e94
Checkify: handle named_call in process_call.
a4cacf57
Remove e Type annotation for jit an pmap as there are additional attr…
799ecfa9
Merge pull request #9558 from jblespiau:changelist/427986940
2c01312d
Add constant handler for tokens.
5a259925
JaxTestCase: set jax_numpy_rank_promotion='raise' by default
97512e9e
Merge pull request #9557 from LenaMartens:changelist/428497052
8b117c50
Merge pull request #9559 from hawkinsp:token
fb4934c2
Fix incorrect binary search comparison in lax.select_n lowering.
29c8a045
Merge pull request #9563 from hawkinsp:selectn
a457320b
Add translation rule for optimization barrier.
7ce911b8
Merge pull request #9562 from jakevdp:disable-rank-promotion
f229a703
Add a warning to random.choice to notify users of the ill-defined beh…
20e5090b
Implement the JAX transfer guard API
beaa00c4
Merge pull request #9564 from MichaelMarien:random-choice-docstring
d5694402
Add type promotion design doc
7381bbe8
Merge pull request #9407 from jakevdp:type-promotion-design
7204ac30
[sparse] Updates `bcoo_dot_general` cuSparse lowering rule by adding …
273ea626
Merge pull request #9485 from rsepassi:compilelog
924f7aab
Add multi-host utilities to JAX core. Adapted from https://github.com…
7613d2a5
Run all tests with jax_traceback_filtering=off.
0d9990e4
Add a partial_eval_jaxpr_custom_rule for xmap
c551beda
Limit the set of unspecified dims to those that are not explicitly co…
b75b0c04
Update the list of default CUDA capabilities used for wheel builds to…
2e0cfe8e
Merge pull request #9551 from RuffaloLavoisier:typo
0b920521
Try adding support for nesting with_sharding_constraint in MANUAL xmaps
a14ccb99
[JAX] Fix crash when applying jit() to a callable that is not weak-re…
b0b8f037
Merge pull request #9573 from hawkinsp:cuda
bf3c6581
type promotion design doc: minor typos
de360833
Checkify: fix check_error of nd-error.
b15c7f60
Fix nomenclature in type promotion doc
e82f232e
Merge pull request #9587 from jakevdp:fix-type-promotion
1da8b502
Merge branch 'main' into opt-barrier
662c4416
Merge pull request #9585 from jakevdp:typos
95c486a5
Merge pull request #9561 from pschuh:opt-barrier
c49fb9c2
fix: use boolean tf functions where possible instead of casting to in…
c1cbf786
fix: missing comma
88a1c049
[jax2tf] Fixed stale documentation about XLA metadata.
461b37b2
[jax2tf] Fixes shape polymorphism for jnp.take_along_axes
1928f6e6
Merge pull request #9595 from gnecula:tf_metadata
35082fce
Merge pull request #9597 from gnecula:tf_take_along_axis
c97fefcc
Set the aval inside _create_local_shards iteration. Since we are iter…
94aade03
Checkify: fix nd-error case when array only has 1 element.
758c7216
Merge pull request #9493 from mattjj:better-pjit-pytree-prefix-error
e1fd6304
Merge pull request #9602 from LenaMartens:changelist/429094321
052b5c36
Add cloudpickle as a test requirement.
901d459e
Merge pull request #9605 from hawkinsp:pickle
e25259e5
adds jax.scipy.schur
514d8883
adds jax.scipy.linalg.sqrtm
cb732323
Merge `mesh` and `Mesh`. Make `Mesh` a context manager + class so tha…
a83695a7
implement `jnp.expand_dims` and `jnp.stack` for PRNGKeyArrays
0f7904f8
Add some functions for spectral analysis.
e085370e
use singleton dims in broadcasting binop batchers
e0fb424d
`del self._old_env` so that you can use `with global_mesh` multiple t…
bd2a6a07
fix: linting errors
b657c23b
Merge pull request #9544 from SaturdayGenfo:adds-matrix-sqrt
15295a82
Fix uninitialized axis_env error when MLIR lowering is disabled
57f42320
Checkify: explicitly export public API, hide private symbols.
73f23705
refactor: rename, reorder, rewrite
6ccbc862
Merge pull request #9609 from froystig:prng-array-stack
032bfe09
err on repeated axes to `expand_dims`, as numpy does
35fab1a9
Clarify the NVidia driver version requirements.
d704c151
in tests, compare jnp operations on PRNGKeyArrays to the same on jnp …
88c6b84d
Merge pull request #9616 from hawkinsp:doc
83a50202
Xmap GDA integration. Non-contiguous mesh is allowed!
6bb58e6f
Merge pull request #9615 from froystig:jnp-expand-dims-error
5bb140f6
Merge pull request #9422 from yotarok:signal_stft
54a6e4da
Add jax.distributed and jax.dlpack to the docs.
3e5ecfe3
Merge pull request #9621 from hawkinsp:docs3
e545daa1
Fix Iree backend to for copy_to_device and executable results
50f4b580
Add deprecation warning to JaxTestCase and JaxTestLoader
da3aaa19
Merge pull request #9620 from jakevdp:deprecate-jax-test-case
d22129f1
Show all the device buffer shapes in the error message.
a430d0f2
Fix a typo (remove `== ss`).
b78ec009
fix: simplify reduce_min with reduce_all
20635e4a
[jax2tf] Added more links to the documentation
35763889
docs: spelling mistake fixed
f926d089
docs: fix explanation
e6d94e11
Merge pull request #9631 from gnecula:tf_doc
d123a10c
Turn execute_replicated into a class so we can access its fields.
607e7033
Merge pull request #9594 from oliverdutton:boolean_jax2tf
1baa59c6
Index update operators: add scatter_apply()
e13c847e
Allow unevenly partitioned sharding_constraints.
8cb16928
* Make _old_env thread local so that it can be used in multiple threads.
1486be7b
Checkify: address initial feedback.
2eeb683d
Revert back to adding aval on Device buffers inside local_shards and …
d1b6f5d9
Finish jax release
c161c628
Make resharding of GDA work if the shape is larger than what it was s…
3290dd3a
IREE's get_default_device_assignment should return List[Device]
23d5eb08
Merge pull request #9607 from jakevdp:scatter-apply
660616cf
Merge pull request #9617 from froystig:prng-array-tests
a65841f5
Add copy button to code snippets in documentation
6423741d
Always treat all mesh axes controlled by xmap as MANUAL
2641f061
Make local_to_global and global_to_local private.
afdb7b6f
Merge pull request #9636 from LenaMartens:changelist/429277776
97b1bd3b
Merge pull request #9613 from mattjj:vmap-binop-batchers-use-singletons
ab15db7d
fixing scan and other control flow
4b1d0a46
Fix tests and handle cond consts.
45d3ddda
Merge pull request #9663 from mattjj:checkify-scan-debug
7edd964c
Make the name `pjit` appear in xprof and mhlo module name. Before eve…
e96b91d4
Add separate mechanism for threading name stacks to the lowering
1b79caa6
Deprecate `maps.mesh` and replace it with `maps.Mesh`.
687a7630
Merge pull request #8395 from sharadmv:name-stack-mechanism
c0416945
Fix the gpu tests that were failing with Future warning
e2834d89
Rename `unmapped_local_out_avals` to `out_avals` since it can contain…
98e114da
Add .device() method to _DeviceArray
a9a827e2
Merge pull request #9052 from jpuigcerver:main
3948fde8
Remove jax.ops.index... functions.
f51a05a8
[JAX] Update comments and documents for ANN.
f2c17439
Merge pull request #9619 from hawkinsp:index2
cdd0c688
[JAX] Move ann.ann_recall back to tests.
8372b98c
Merge pull request #9658 from tornikeo:docs-copybutton
d2834b64
Remove duplicate changelog entry
51727033
Merge pull request #9684 from jakevdp:fix-changelog
ffcf4773
Merge pull request #9641 from rsuderman:FixDefaultDevice
6914bae3
Merge pull request #9623 from rsuderman:FixIreeJax
d5a1c64d
DOC: update sharp bits info about fori_loop differentiability
74e7fdfd
BUG: return numpy arrays for jnp.load() with unsupported dtypes
1b01865b
Merge pull request #9702 from jakevdp:fix-load
7ec04c83
Relax tolerances slightly for MKL.
bf59b7d8
Merge pull request #9706 from samuela:samuela/mkl-fix
d583f876
make `xla_executable` a property, consistent across executable types
d636e746
Merge pull request #9193 from froystig:xla-executable-attr
c66fbeb9
Remove an unnecessary condition.
25472c23
Merge pull request #9701 from jakevdp:fix-sharp-bits
0b9ae982
Merge pull request #9711 from jblespiau:changelist/431385993
043561ae
Bump actions/setup-python from 2 to 3
680c06dd
[XLA:CPU] Relax test tolerances for tests using XLA:CPU.
c339330b
Merge pull request #9718 from google:dependabot/github_actions/action…
92cb865b
jax.random.poisson: fix corner cases
2c2773a5
reza-amd
force pushed
from
18420dcc
to
806163bc
3 years ago
Add `block_until_ready` method to GDA
d0cc3395
Merge pull request #9721 from jakevdp:poisson-nan
c7508d1f
Handle jaxpr constants correctly in MLIR lowering of conditional bran…
cffe9978
Factor-out pieces of lax_numpy.py
d09d7b8d
Merge pull request #9724 from jakevdp:refactor-lax-numpy
5b309880
Make the global_mesh that GDA has public
11664f8a
Update JAX to use new math libraries in ROCm-5.0.
a0d9d81f
reza-amd
closed this
3 years ago
reza-amd
reopened this
3 years ago
reza-amd
force pushed
from
51157ee7
to
a0d9d81f
3 years ago
reza-amd
closed this
3 years ago
gulsumgudukbay
deleted the rocm_refactor_pr branch
214 days ago
Login to write a write a comment.
Login via GitHub
Reviewers
No reviews
Assignees
No one assigned
Labels
None yet
Milestone
No milestone
Login to write a write a comment.
Login via GitHub