The pattern used in this change (`obj = object.__new__(data_clz)`) is not universally applicable to all Flax subclasses.
Here is a minimal repro:
```
from flax import struct
import jax
from collections import OrderedDict
@struct.dataclass
class Foo(OrderedDict):
x: int
foo = Foo(1)
print(jax.tree_map(lambda x: x + 1, foo))
```
PiperOrigin-RevId: 526744687