flax
fc38f21e - Introduce a mode for a faster nn.scan mode that avoids extra jax retracing.

Commit
1 year ago
Introduce a mode for a faster nn.scan mode that avoids extra jax retracing. This adds a new keyword option to linen nn.scan `check_constancy_invariants` that defaults to True for the existing behavior. Setting it to False however avoids an extra jax trace to hoist scan loop constants out of the loop and to check for non-data-dependence of broadcast variables and body function outputs marked constant. The time savings from not running this extra trace and static check can be considerable when tracing and compiling larger models. PiperOrigin-RevId: 705869200
Author
Committer
Parents
Loading