transformers
0780b0b8 - fix: initialize BatchNorm2d buffers only when needed (#43520)

Committed 108 days ago
fix: initialize BatchNorm2d buffers only when needed (#43520)

Commit 8dd9c999a6262d6ceb48f4a2da7acaccfa80e3bc introduced a regression by unconditionally reinitializing the BatchNorm2d buffers (running_mean, running_var, num_batches_tracked) in the _init_weights() method.

The problem: when loading pretrained timm models, the flow is:
1. timm.create_model(pretrained=True) correctly loads the pretrained BatchNorm statistics.
2. post_init() calls _init_weights() on all modules.
3. _init_weights() overwrites the pretrained BatchNorm buffers with zeros/ones.
4. The model produces incorrect results because normalization uses the wrong statistics.

The error manifests as a mismatch at index (0, 178, 0, 0) with a difference of 80224.0 on CircleCI on AVX-512 CPUs.

The fix: skip initialization when using a pretrained backbone, since the buffers are already loaded from the checkpoint.
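The guard described above can be sketched in plain Python. This is an illustrative stand-in, not the actual transformers code: FakeBatchNorm, init_weights, and the use_pretrained_backbone flag are hypothetical names chosen to mirror the pattern in the commit message.

```python
class FakeBatchNorm:
    """Stand-in for nn.BatchNorm2d, holding the three affected buffers.

    The initial values pretend to be statistics loaded from a
    pretrained checkpoint (step 1 in the flow above).
    """
    def __init__(self):
        self.running_mean = [0.123]
        self.running_var = [4.567]
        self.num_batches_tracked = 42


def init_weights(module, use_pretrained_backbone):
    """Sketch of the fixed _init_weights() behavior.

    The regression ran this reset unconditionally, clobbering the
    checkpoint statistics; the fix guards it behind the pretrained check.
    """
    if not use_pretrained_backbone:
        module.running_mean = [0.0]
        module.running_var = [1.0]
        module.num_batches_tracked = 0


bn = FakeBatchNorm()
init_weights(bn, use_pretrained_backbone=True)
# With a pretrained backbone the checkpoint statistics survive.
assert bn.running_mean == [0.123]
```

With use_pretrained_backbone=False the buffers are reset to the usual zeros/ones defaults, which is the correct behavior when training from scratch.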