jax
f2121a72 - [platform_dependent] Ensure that platform_dependent only lowers for intended platforms

Commit
293 days ago
[platform_dependent] Ensure that platform_dependent only lowers for intended platforms Fixes: #28594 Currently `lax.platform_dependent` allows specifying code that behaves differently when lowered on different platforms. However, this function operates in a confusing way, in that it will create a branch on the platform, but will lower all branches for the **current** lowering platforms. For example, in the following code: ``` lax.platform_dependent(x, cpu=for_cpu, tpu=for_tpu) ``` If we lower for CPU, we lower both `for_cpu` and `for_tpu` for CPU (!), but only the branch corresponding to `for_cpu` will actually run. This is a problem if, e.g., `for_tpu` does not have a lowering for CPU. We will get an error during lowering. Instead there should be no error during lowering, because that branch is not actually needed. We add a new test `test_platform_dependent_with_primitive_with_lowering_error` to demonstrate this. The solution implememented here is the Solution A from #28594: we add a `branches_platform` param to the `cond` primitive, which is propagated by all transformations. This param is used only for the conditionals arising from `lax.platform_dependendet`. During lowering we drop the branches corresponding to the platforms that are not interesting.
Author
Committer
Parents
Loading