pytorch
52a2b612 - Fix fetch function which breaks user code (#85099)

Commit
3 years ago
Fix fetch function which breaks user code (#85099) The [fastNLP](https://github.com/fastnlp/fastNLP/blob/v0.6.0/fastNLP/core/batch.py#L51) model uses DataSetGetter to fetch data from the dataset. The following code breaks because of https://github.com/pytorch/pytorch/pull/84301: ``` from fastNLP.io.pipe.qa import CMRC2018BertPipe input_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".data", "cmrc2018-sim") data_bundle = CMRC2018BertPipe().process_from_file(paths=input_dir) data_bundle.rename_field('chars', 'words') data_bundle.get_dataset('dev') dataset = DataSetGetter(dataset, as_numpy) dataiter = torch.utils.data.DataLoader(dataset=dataset) for batch in dataiter: # data-processing... ``` This is because for the `DataSetGetter` class, the following condition holds: ``` # hasattr(dataset_getter, '__getitems__') == True # dataset_getter.__getitems__ == None ``` This PR adds an additional check to make sure `__getitems__` is only called when it is not None. This error was found by the torchbench nightly CI, original error stack trace: ``` ERROR: test_fastNLP_Bert_train_cuda (__main__.TestBenchmark) ---------------------------------------------------------------------- components._impl.workers.subprocess_rpc.ChildTraceException: Traceback (most recent call last): File "/home/circleci/project/components/_impl/workers/subprocess_rpc.py", line 470, in _run_block exec( # noqa: P204 File "<subprocess-worker>", line 35, in <module> File "<subprocess-worker>", line 12, in _run_in_worker_f File "/home/circleci/project/torchbenchmark/util/model.py", line 16, in __call__ obj = type.__call__(cls, *args, **kwargs) File "/home/circleci/project/torchbenchmark/models/fastNLP_Bert/__init__.py", line 93, in __init__ self.example_inputs = self._prefetch(example_inputs) File "/home/circleci/project/torchbenchmark/models/fastNLP_Bert/__init__.py", line 133, in _prefetch for batch_x, batch_y in example_inputs: File "/home/circleci/miniconda3/lib/python3.8/site-packages/fastNLP/core/batch.py", line 266, in __iter__ for indices, batch_x, batch_y in self.dataiter: File "/home/circleci/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 681, in __next__ data = self._next_data() File "/home/circleci/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 719, in _next_data data = self._dataset_fetcher.fetch(index) # may raise StopIteration File "/home/circleci/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 56, in fetch data = self.dataset.__getitems__(possibly_batched_index) TypeError: 'NoneType' object is not callable ``` Full error log: https://app.circleci.com/pipelines/github/pytorch/benchmark/5143/workflows/0676f36d-0ab4-42bd-adb4-90e6b0df76d1/jobs/5293 Pull Request resolved: https://github.com/pytorch/pytorch/pull/85099 Approved by: https://github.com/ejguan
Author
Committer
Parents
Loading