jax
e9b087d3 - Add `ffi_call` function with a similar signature to `pure_callback`.

Commit
1 year ago
Add `ffi_call` function with a similar signature to `pure_callback`. This could be useful for supporting the most common use cases for FFI custom calls. It has several benefits over using the `Primitive` based approach, but the biggest one (in my opinion) is that it doesn't require interacting with `mlir` at all. It does have the limitation that transforms would need to be registered using interfaces like `custom_vjp`, but many users of custom calls already do that. ~~The easiest to-do item (I think) is to implement batching using a `vectorized` parameter like `pure_callback`, but we could also think about more sophisticated vmapping interfaces in the future.~~ Done. The more difficult to-do is to think about how to support sharding, and we might actually want to expose an interface similar to the one from `custom_partitioning`. I have less experience with this part so I'll have to think some more about it, and feedback would be appreciated!
Author
dfm dfm
Committer
dfm dfm
Parents
Loading