82ab167c - [NNC] Fix masking for all block and thread dimensions in CudaCodeGen (#44733)

Summary: Unifies a number of partial solutions to thread- and block-dimension extent masking, including the NoThreadIdxWriter and my last fix (https://github.com/pytorch/pytorch/issues/44325). The NoThreadIdxWriter is gone in favour of tracking the current loop extents and masking any statement that has a lower rank than the launch parameters in any block or thread dimension, which handles both the "no" and "smaller" axis-binding cases. For example, it will transform the following:

```
for i in 0..10 // blockIdx.x
  for j in 0..10 // threadIdx.x
    do thing(i, j);
  for k in 0..5 // threadIdx.x
    do other thing(i, k);
```

into:

```
do thing(blockIdx.x, threadIdx.x);
if (threadIdx.x < 5) {
  do other thing(blockIdx.x, threadIdx.x);
}
```

It also handles the case where statements are not bound by any axis, e.g.:

```
do outer thing;
for i in 0..10 // blockIdx.x
  for j in 0..10 // threadIdx.x
    do thing(i, j);
  do other thing(i);
```

will become:

```
if (blockIdx.x < 1) {
  if (threadIdx.x < 1) {
    do outer thing;
  }
}
syncthreads();
do thing(blockIdx.x, threadIdx.x);
syncthreads();
if (threadIdx.x < 1) {
  do other thing(blockIdx.x);
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44733

Reviewed By: mruberry

Differential Revision: D23736878

Pulled By: nickgg

fbshipit-source-id: 52d08626ae8043d53eb937843466874d479a6768
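To make the second transformation concrete, here is a minimal hand-written CUDA sketch of what the masked output could look like once lowered (the kernel name, buffers, and statement bodies are hypothetical, not taken from the PR; `syncthreads()` in the pseudocode corresponds to CUDA's `__syncthreads()`). The axis-unbound statement is masked down to a single thread, the fully bound statement runs unmasked, and the partially bound statement keeps only a threadIdx.x mask:

```
#include <cuda_runtime.h>

// Hypothetical lowering of the second example, launched as <<<10, 10>>>
// so blockIdx.x and threadIdx.x both have extent 10.
__global__ void fused_kernel(float* out, float* other, float* flag) {
  // "do outer thing": bound to no axis, so masked to a single thread.
  if (blockIdx.x < 1) {
    if (threadIdx.x < 1) {
      flag[0] = 1.0f;
    }
  }
  __syncthreads();
  // "do thing(i, j)": bound to both axes, runs on every thread unmasked.
  out[blockIdx.x * 10 + threadIdx.x] = blockIdx.x + threadIdx.x;
  __syncthreads();
  // "do other thing(i)": bound only to blockIdx.x, so masked in threadIdx.x.
  if (threadIdx.x < 1) {
    other[blockIdx.x] = out[blockIdx.x * 10];
  }
}

int main() {
  float *out, *other, *flag;
  cudaMalloc(&out, 10 * 10 * sizeof(float));
  cudaMalloc(&other, 10 * sizeof(float));
  cudaMalloc(&flag, sizeof(float));
  fused_kernel<<<10, 10>>>(out, other, flag);
  cudaDeviceSynchronize();
  cudaFree(out);
  cudaFree(other);
  cudaFree(flag);
  return 0;
}
```

The smaller-extent case from the first example follows the same pattern, with a guard of `if (threadIdx.x < 5)` instead of `< 1`. Note that `__syncthreads()` only synchronizes within a block, which suffices here because each masked statement's output is read (if at all) only by threads of the same block.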