Eliminate Named tensor warnings in XNNPACK and QNNPACK
XNNPACK and QNNPACK propagate dimension names during allocation, but
they never create named tensors from scratch. However, they were
generating warnings ("Warning: Named tensors and all their associated
APIs are an experimental feature...") from the call
`namedinference::propagate_names_if_nonempty(t1, t2.names())`,
because `t2.names()` materializes a non-empty `DimnameList` of
wildcards even for unnamed tensors.
Introduce `propagate_names_if_present_and_nonempty`, which takes an
`optional<DimnameList>` obtained from `t2.opt_names()`, and use it from
QNNPACK and XNNPACK to eliminate this warning. Another option would
have been to make `propagate_names_if_nonempty` take an optional, which
would be more or less backward compatible, since
`propagate_names_if_nonempty(t1, t2.names())` would still compile via
implicit conversion to optional. I chose not to do this only because it
seemed riskier.
Test Plan:
For QNNPACK, I just ran the NNAPI test, which was previously throwing a
warning. For XNNPACK, its uses are gated on C10_MOBILE, so I needed to
export a model and run it with a mobile build:
```python
import torch
import torch.utils.bundled_inputs
import torch.utils.mobile_optimizer
import torch.nn.functional as F
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.convt = torch.nn.ConvTranspose2d(1, 1, 1)
        self.lin = torch.nn.Linear(1, 1)

    def forward(self, t):
        t = self.conv(t)
        t = self.convt(t)
        t = F.hardswish(t, False)
        t = F.hardswish(t[..., ::2], True)
        t = F.adaptive_avg_pool2d(t, (1, 1))
        t = torch.broadcast_to(t, (1, 4, 1, 1))
        t = t.contiguous(memory_format=torch.channels_last)
        t = F.channel_shuffle(t, 2)
        t = F.max_pool2d(t.reshape(1, 1, 2, 2), (2, 2))
        t = self.lin(t.reshape(1, 1).contiguous())
        return t

arg = torch.zeros(1, 1, 1, 3)
mtr = torch.jit.trace(MyModel().eval(), arg)
mom = torch.utils.mobile_optimizer.optimize_for_mobile(mtr)
mbi = torch.utils.bundled_inputs.bundle_inputs(mom, [((arg,))])
torch.jit.save(mbi, "/tmp/model.ptj")
```
```bash
nice env CMAKE_CXX_COMPILER_LAUNCHER=ccache CMAKE_C_COMPILER_LAUNCHER=ccache ./scripts/build_mobile.sh -DBUILD_BINARY=1
./build_mobile/bin/speed_benchmark_torch --model=/tmp/model.ptj --use_bundled_input=0
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77762
Approved by: https://github.com/zou3519