ed0091f8 - fix view_copy kernel striding check logic (#81553)

fix view_copy kernel striding check logic (#81553)

The composite kernel for `view_copy` that we generate is special-cased a bit for efficiency, to avoid having to do extra clones in some cases. That logic was slightly wrong and is fixed here: it needs to mirror the logic in `reshape()`. The bug manifested as a debug assert firing for Lazy Tensor, which I confirmed no longer fires when running this script:

```
# ran with "python test_ltc_only_torch.py --device=lazy --sync=1 --nvtx=1"
import torch
import torch._lazy
from torch._lazy.ts_backend import init as init_ts_backend

init_ts_backend()
torch.manual_seed(42)

from transformers import BertForSequenceClassification


def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--sync', type=bool, default=False)
    parser.add_argument('--nvtx', type=bool, default=False)
    return parser.parse_args()


args = parse_args()
device = args.device

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', return_dict=True)

from transformers import AdamW
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text_batch = ["I love Pixar.", "I don't care for Pixar."]
encoding = tokenizer(text_batch, return_tensors='pt', padding=True, truncation=True)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)

model = model.to(device)
model.train()

no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)

labels = torch.tensor([1,0]).unsqueeze(0).to(device)

for _ in range(6):
    torch.cuda.nvtx.range_push(f'Iter{_}')

    torch.cuda.nvtx.range_push('F')
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    if args.sync:
        torch._lazy.mark_step()
        torch._lazy.wait_device_ops()
    torch.cuda.nvtx.range_pop()

    loss = outputs.loss

    torch.cuda.nvtx.range_push('B')
    optimizer.zero_grad()
    loss.backward()
    if args.sync:
        torch._lazy.mark_step()
        torch._lazy.wait_device_ops()
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push('O')
    optimizer.step()
    if args.sync:
        torch._lazy.mark_step()
        torch._lazy.wait_device_ops()
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_pop()

torch._lazy.mark_step()
torch._lazy.wait_device_ops()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81553
Approved by: https://github.com/ezyang
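For context, the distinction the composite kernel has to respect is the same one `reshape()` makes: when the requested shape is stride-compatible with the input, the result can alias the original storage and a single copy suffices; when it is not, the data first has to be materialized contiguously, which is where the extra clone comes from. Below is a small Python sketch (illustrative only, not the generated C++ kernel from this commit) showing that distinction with stock PyTorch ops:

```
import torch

# Contiguous input: the requested shape is stride-compatible, so reshape()
# returns a view that aliases the original storage -- no copy required.
a = torch.arange(6)
assert a.reshape(2, 3).data_ptr() == a.data_ptr()

# Transposed input: no valid strides exist for the flattened shape, so
# reshape() must materialize a contiguous copy instead.
b = torch.arange(6).reshape(2, 3).t()
assert b.reshape(6).data_ptr() != b.data_ptr()

# view() refuses outright in the stride-incompatible case, which is why the
# composite view_copy kernel has to check stride compatibility up front
# before deciding whether an extra clone is actually needed.
try:
    b.view(6)
except RuntimeError as err:
    print("view() rejected the incompatible striding:", err)
```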