Added a `models_dict` to `DataParallel`, keyed by device. This helps with fairseq, where the model isn't used directly inside `train_loop_fn`; training instead goes through fairseq's `Trainer` class, which holds a criterion and a model attribute that need to be unique per device.
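For context, a minimal sketch of what the addition amounts to inside `DataParallel.__init__`. The replication loop and every name other than `models_dict` are illustrative assumptions here, not the actual torch_xla source:

```
import torch

class DataParallel(object):
    def __init__(self, network, device_ids):
        # Existing behavior (sketched): call the `network` factory once per
        # device, so every device gets its own torch.nn.Module replica.
        self._models = []
        self.models_dict = {}  # the addition: per-device replica lookup
        for device in device_ids:
            module = network().to(device=torch.device(device))
            self._models.append(module)
            self.models_dict[str(device)] = module
```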
With this change:
* I can access the trainer inside `train_loop_fn`.
* I can still instantiate `model_parallel` with a function that returns the `model`, i.e. a `torch.nn.Module`:
```
import torch_xla.core.xla_model as xm
import torch_xla.distributed.data_parallel as dp
from fairseq.trainer import Trainer

# DEVICES, BATCH_SIZE, args, task, and fairseq_train (fairseq's train
# script, which provides get_training_stats) are defined elsewhere.
model_parallel = dp.DataParallel(
    lambda: task.build_model(args), device_ids=DEVICES)

# One criterion and one Trainer per device, each wrapping that device's
# model replica from models_dict.
criteria = {
    device: task.build_criterion(args) for device in DEVICES}
trainers = {
    device: Trainer(args, task, model, criteria[device])
    for device, model in model_parallel.models_dict.items()
}
...

def train_loop_fn(model, loader, device, context):
    trainer = trainers[str(device)]
    tracker = xm.RateTracker()
    for i, samples in loader:
        _log_output = trainer.train_step(samples)  # applies model to input and more
        xm.optimizer_step(trainer.optimizer)
        tracker.add(len(samples) * BATCH_SIZE)
    stats = fairseq_train.get_training_stats(trainer)
    return tracker, stats
```
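Note that `models_dict` is keyed by the device string, which is why `train_loop_fn` looks up its trainer with `trainers[str(device)]` rather than with the device object it receives.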