[inductor] remove RBLOCK from persistent reduction kernel's parameter list (#98653)
This PR resolves the review comments at https://github.com/pytorch/pytorch/pull/97203#discussion_r1160491318. It is sent as a separate PR since that makes it easier to test and confirm there is no perf impact.
Tests:
1. Run `python test/inductor/test_torchinductor.py`.
2. Run `python benchmarks/dynamo/torchbench.py --backend inductor --amp --performance --dashboard --only hf_Bert --disable-cudagraphs --training` before and after the change to confirm the perf impact is neutral.
After this change, a persistent reduction kernel in hf_Bert looks like:
```
@persistent_reduction(
    size_hints=[4096, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*i64', 2: '*fp16', 3: '*i64', 4: '*fp16', 5: '*i64', 6: '*fp16', 7: '*fp16', 8: '*fp16', 9: '*fp16', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), equal_to_1=())]}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 4096
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
```
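The key difference is that `RBLOCK: tl.constexpr = 1024` is now assigned inside the kernel body rather than passed as an extra launch argument. As a rough illustration of the pattern (a hand-written sketch, not Inductor-generated code; the kernel name, pointer names, and the row-sum reduction are made up for the example), a persistent reduction with a body-level constexpr `RBLOCK` looks like:
```
import triton
import triton.language as tl

@triton.jit
def row_sum_kernel(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK: tl.constexpr):
    # RBLOCK is a compile-time constant baked into the kernel body,
    # so the caller no longer passes it in the launch signature.
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    rindex = tl.arange(0, RBLOCK)[None, :]
    xmask = xindex < xnumel
    rmask = rindex < rnumel
    # Each program loads a whole row ("persistent" reduction) and sums it.
    row = tl.load(in_ptr0 + (rindex + rnumel * xindex), rmask & xmask, other=0.0)
    acc = tl.sum(row, 1)[:, None]
    tl.store(out_ptr0 + xindex, acc, xmask)
```
A caller would launch this sketch with something like `row_sum_kernel[(triton.cdiv(xnumel, XBLOCK),)](x, out, xnumel, rnumel, XBLOCK=8)`; only `XBLOCK` remains a launch-time constant, while `RBLOCK` is fixed when the kernel source is generated.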
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98653
Approved by: https://github.com/jansel