flax
5ffd0006 - Add data_transform argument to nn.scan and preserve PartitionSpec attributes.

Commit
4 years ago
Add data_transform argument to nn.scan and preserve PartitionSpec attributes. pjit PartitionSpec Module attributes were being downcast to tuples during freezing. An extra data_transform kwarg to nn.scan is used to help fix an issue where XLA SPMD constraints don't propagate across XLA while loops. It allows us to use a workaround to re-apply SPMD constraints inside the scan body function. (Ultimately we hope to find a better upstream fix in JAX/XLA.) PiperOrigin-RevId: 383231061
Author
Committer
Parents
Loading