Add data_transform argument to nn.scan and preserve PartitionSpec attributes.
pjit PartitionSpec Module attributes were being downcast to tuples during freezing.
An extra data_transform kwarg to nn.scan is used to help fix an issue where XLA SPMD
constraints don't propagate across XLA while loops. It allows us to use a
workaround to re-apply SPMD constraints inside the scan body function.
(Ultimately we hope to find a better upstream fix in JAX/XLA.)
PiperOrigin-RevId: 383231061