Allow Flax lifted transforms to work on partially applied Modules.
Often people like applying functools.partial to a nn.Module constructor to
specialize it. At the moment these can't be transformed by nn.vmap, nn.scan, etc.
throwing an error. This PR fixes this issue.
PiperOrigin-RevId: 367489880