pytorch
0d66db1b - Implement last dim split_with_sizes for NT (forward only, non-SymInt-ified) (#97446)

Commit
3 years ago
Implement last dim split_with_sizes for NT (forward only, non-SymInt-ified) (#97446) This is needed for the HSTU model. Details: * ~~NT `chunk` now calls into NT `split_with_sizes` since the latter is more general~~ (removed; they're totally separate) * Throws for backward * Only operates over the last dim (`dim=-1`) Pull Request resolved: https://github.com/pytorch/pytorch/pull/97446 Approved by: https://github.com/cpuhrsch
Author
Committer
Parents
Loading