jax
9adb67fd - PR #30247: Add Shardy rules for fused_attention_stablehlo.

Commit
214 days ago
PR #30247: Add Shardy rules for fused_attention_stablehlo. Imported from GitHub PR https://github.com/jax-ml/jax/pull/30247 I tried to match what the gspmd sharding inference does, without looking into the op itself too deeply. If Shardy has additional requirements, please let me know. Only tested with this unit test: ``` bazel test tests:fused_attention_stablehlo_test_gpu --test_env=JAX_USE_SHARDY_PARTITIONER=1 ``` Copybara import of the project: -- 4757bd8d7699a6db5c2461e4a84ce63ed121ff9b by Johannes Reifferscheid <jreiffers@nvidia.com>: Add Shardy rules for fused_attention_stablehlo. I tried to match what the gspmd sharding inference does, without looking into the op itself too deeply. If Shardy has additional requirements, please let me know. Only tested with this unit test: ``` bazel test tests:fused_attention_stablehlo_test_gpu --test_env=JAX_USE_SHARDY_PARTITIONER=1 ``` Merging this change closes #30247 COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/30247 from jreiffers:main 4757bd8d7699a6db5c2461e4a84ce63ed121ff9b PiperOrigin-RevId: 784595275
Author
Parents
Loading