jax
96b7dbab - [JAX] Implement an initial object API for colocated Python

Commit
1 year ago
[JAX] Implement an initial object API for colocated Python Colocated Python adds `colocated_python_class`. This API wraps a user-defined class for automatic remoting of object construction/destruction and method calls: * An object will be initialized on the backend. At least for now, initialization is deferred until the first method is called; at this point, colocated Python knows what devices the objects should be accessible and thus it can construct the object(s). * When an object method is called, the method call runs as a colocated Python function call on the backend. * When the object is destroyed (either by reaching a zero reference count or through Python GC), destruction also runs as a colocated Python function call and destroys all objects from the backend. This change provides an intial API implementation. Main limitations are as follows: * The methods of a colocated Python class does not support specialization. Calling it requires at least one argument. * Colocated Python objects cannot reference or interact with each other on the controller or on the colocated Python backend. These limitations will be lifted as the object API implementation is improved. PiperOrigin-RevId: 729629265
Author
Parents
Loading