flax
c2e85837 - Allow Flax lifted transforms to work on partially applied Modules.

Commit
4 years ago
Allow Flax lifted transforms to work on partially applied Modules. Often people like applying functools.partial to a nn.Module constructor to specialize it. At the moment these can't be transformed by nn.vmap, nn.scan, etc. throwing an error. This PR fixes this issue. PiperOrigin-RevId: 367489880
Author
Committer
Parents
Loading