flax
37123d52 - Add functool.wraps() annotation to flax.nn.jit.

Commit
1 year ago
Add functool.wraps() annotation to flax.nn.jit. At the moment, all the jit names in a jaxpr show up as "jitted". functools.partial does not forward names. PiperOrigin-RevId: 648760671
Author
Committer
Parents
Loading