jax
fcf5115f - [Pallas Fuser] Add output_fusion_mask support

Commit
290 days ago
[Pallas Fuser] Add output_fusion_mask support Currently, the fusion API assumes by default that all of the outputs of a @fuse-decorated function are computed jointly in one big output fusion. For example, in the following snippet ```python @fuse def f(x, y): z1, z2 = fusable_f(x, y) return g(z1, z2) ``` it assumes that `g` is a single function that operates on z1 and z2 jointly. However, in practice, the fusable may want two separate output fusions: ```python @fuse def f(x, y): z1, z2 = fusable_f(x, y) return g1(z1), g2(z2) ``` This is a special case of the general function but the fusable may not be materializing z1 and z2 at the same time so may not be able to compute this efficiently with a single function g. By decorating a fusable with an output fusion prefix (in the above example `(True, True)`), the fusable will now be given a pair of functions `g1` and `g2` if the output fusion is "separable". For example, we'd error for the following example: ```python @fuse def f(x, y): z1, z2 = fusable_f(x, y) return z1 + z2 ``` because z1 and z2 interact with each other in the output fusion. The rationale for providing a PyTree prefix (as opposed to a more general mechanism) is that the fusable can group its outputs into subtrees that it can identify with the output prefix. This does restrict the types of output groups that are possible (outputs must be part of the same shared subtree, as opposed to arbitrarily scattered throughput the output pytree), but this is an okay restriction because the fusable author is responsible for the grouping and can always construct it that way. PiperOrigin-RevId: 744784770
References
Author
Parents
Loading