transformers
d4bd33cc - Register ModelOutput subclasses as supported torch.utils._pytree nodes (#25358)

Commit
2 years ago
Register ModelOutput subclasses as supported torch.utils._pytree nodes (#25358) * Register ModelOutput subclasses as supported torch.utils._pytree nodes Fixes #25357 where DDP with static_graph=True does not sync gradients when calling backward() over tensors contained in ModelOutput subclasses * Add test for torch pytree ModelOutput serialization and deserialization
Author
Matthew Hoffman
Parents
Loading