flax
28b5b425 - Dramatically speed up sampling compilation time

Commit
304 days ago
Dramatically speed up sampling compilation time On DD:2x2 hardware, this reduces compilation time from roughly 120s to 6s We separate out the parameters from the model graph so that the parameters are passed as a parameter to the jitted function, rather than being kept static. PiperOrigin-RevId: 735316560
Author
Committer
Parents
Loading