flax
6265b48d - change defaults for Dropout and BatchNorm

Commit
6 days ago
change defaults for Dropout and BatchNorm Changes `Dropout.deterministic` and `BatchNorm.use_running_average` to be None by default, use now has to explicitely provide them by either: 1. Passing them to the constructor e.g: self.bn = nnx.BatchNorm(..., use_running_average=False) 2. Passing them to __call__: self.dropout(x, deterministic=False) 3. Using `nnx.view` to create a view of the model with specific values: train_model = nnx.view(model, detereministic=False, use_running_average=False) PiperOrigin-RevId: 877557940
Author
Cristian Garcia
Committer
Parents
Loading