pytorch
09f79e94 - support nested_tensor * scalar (#80284)

Commit
2 years ago
support nested_tensor * scalar (#80284) In transformer, the scale step in attention has a `nested_tensor / scalar` operation. There are two ways to support that: 1. directly support `nested_tensor / scalar`: * pro: straightforward, good UX * con: is dispatching `mul(nested tensor, regular tensor)` a good practice? 2. let user manually convert `scalar` to `nested_scalar = torch.nested_tensor([broadcast_scalar])` * pro: dispatcher only has to deal with `mul(nested tensor, nested tensor)` * con: confusing manual conversions, bad UX Pull Request resolved: https://github.com/pytorch/pytorch/pull/80284 Approved by: https://github.com/cpuhrsch
Author
Yifan Shen
Committer
Parents
Loading