jax
4a43a075 - replace fusible_p with Fusible(VJPHiPrimitive)

Commit
14 days ago
replace fusible_p with Fusible(VJPHiPrimitive) `fusible_p` is replaced with a new `Fusible` primitive type that inherits from `VJPHiPrimitive`, and implements the `expand`, `vjp_fwd`, `vjp_bwd_retval`, and `batch` methods. The physicalize rule for `fusible_p` is removed in favor of a rule for `hijax.call_hi_primitive_p` which is what `VJPHiPrimitive` emits. `fuse_jaxpr` is updated to search for and use `call_hi_primitive_p` primitives accordingly. Reverts e7884f87d11d0beb9a47051b94c2f5c3ad763b5f PiperOrigin-RevId: 903555163
Parents
Loading