Prioritize user-defined `train()` function over the staged `forward()` (#2174)
Summary:
It gives more user-friendly error messages upon unimplemented train tests.
Fixes https://github.com/pytorch/benchmark/issues/2166
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2174
Test Plan:
```
$ python -u run.py -d cuda -t train --bs 4 --metrics None hf_Whisper
/home/runner/miniconda3/envs/torchbench/lib/python3.11/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
/home/runner/miniconda3/envs/torchbench/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
/home/runner/miniconda3/envs/torchbench/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
Running train method from hf_Whisper on cuda in eager mode with input batch size 4 and precision fp32.
Traceback (most recent call last):
File "/workspace/benchmark/run.py", line 623, in <module>
main() # pragma: no cover
^^^^^^
File "/workspace/benchmark/run.py", line 593, in main
run_one_step(
File "/workspace/benchmark/run.py", line 173, in run_one_step
func()
File "/workspace/benchmark/torchbenchmark/util/model.py", line 315, in invoke
self.train()
File "/workspace/benchmark/torchbenchmark/models/hf_Whisper/__init__.py", line 20, in train
raise NotImplementedError("Training is not implemented.")
NotImplementedError: Training is not implemented.
```
Reviewed By: aaronenyeshi
Differential Revision: D54012510
Pulled By: xuzhao9
fbshipit-source-id: bb27bd5adb0bcd778c2c58db7ef5a7b8cc9b2c20