xla
9a56a152 - added a models_dict to DataParallel

Commit
6 years ago
added a models_dict to DataParallel keys are devices This helps w/ fairseq, where model itself isn't used directly inside `train_loop_fn`, but rather their `Trainer` class, which has a loss and a model attribute, that needs to be unique per device. With this change, * I can access the trainer inside `train_loop_fn` * I can still instantiate `model_parallel` with a function that returns the `model`, i.e. a `torch.nn.Module` ``` model_parallel = dp.DataParallel( lambda: task.build_model(args), device_ids=DEVICES) criteria = { device: task.build_criterion(args) for device in DEVICES} trainers = { device: Trainer(args, task, model, criteria[device]) for device, model in model_parallel.models_dict.items() } .... def train_loop_fn(model, loader, device, context): trainer = trainers[str(device)] for i, samples in loader: _log_output = trainer.train_step(samples) # applies model to input and more xm.optimizer_step(trainer.optimizer) tracker.add(len(samples) * BATCH_SIZE) stats = fairseq_train.get_training_stats(trainer) return tracker, stats ```
Author
Parents
Loading