flax
4e83e091 - [linen] generalize transform caching

Commit
1 year ago
[linen] generalize transform caching * Renames `decorator_lift_transform_jit` to `decorator_lift_transform_cached` and `module_class_lift_transform_jit` to `module_class_lift_transform_cached`, and generalizes them to accept a `transform`. * Adds `lift_transfom_cached` to allow lifting any transform using the functions above. * Updates `lift.checkpoint` so it can be used with `lift_transfom_cached`. * Fixes potential bug in `lift.jit`. PiperOrigin-RevId: 654696043
Author
Cristian Garcia
Committer
Parents
Loading