diff mbox series

RISC-V: Optimize branches with shifted immediate operands

Message ID DB9PR08MB6634DBA4BE2AC869CCC2EAD2B7922@DB9PR08MB6634.eurprd08.prod.outlook.com
State New
Headers show
Series RISC-V: Optimize branches with shifted immediate operands | expand

Commit Message

Jovan Vukic Sept. 2, 2024, 1:52 p.m. UTC
The patch adds a new instruction pattern to handle conditional branches with equality checks between shifted arithmetic operands. This pattern optimizes the use of shifted constants (with trailing zeros), making it more efficient.

For the C code:
void f5(long long a) {
  if ((a & 0x2120000) == 0x2000000)
    g();
}

before the patch, the assembly code was:
f5:
      li    a5,34734080
      and   a0,a0,a5
      li    a5,33554432
      beq   a0,a5,.L21
      ret

and after the patch the assembly is:
f5:
      srli  a5,a0,17
      andi  a5,a5,265
      li    a4,256
      beq   a5,a4,.L21
      ret

Tested on both RV32 and RV64 with no regressions.

2024-09-02  Jovan Vukic  <Jovan.Vukic@rt-rk.com>

gcc/ChangeLog:
      PR target/113248
      * config/riscv/riscv.md (*branch<ANYI:mode>_shiftedarith_equals_shifted): New pattern.

gcc/testsuite/ChangeLog:
      PR target/113248
      * gcc.target/riscv/branch-1.c: Additional tests.

---
 gcc/config/riscv/riscv.md                 | 32 +++++++++++++++++++++++
 gcc/testsuite/gcc.target/riscv/branch-1.c | 16 +++++++++---
 2 files changed, 45 insertions(+), 3 deletions(-)

--
2.43.0
CONFIDENTIALITY: The contents of this e-mail are confidential and intended only for the above addressee(s). If you are not the intended recipient, or the person responsible for delivering it to the intended recipient, copying or delivering it to anyone else or using it in any unauthorized manner is prohibited and may be unlawful. If you receive this e-mail by mistake, please notify the sender and the systems administrator at straymail@rt-rk.com immediately.

Comments

Jeff Law Sept. 3, 2024, 5:20 p.m. UTC | #1
On 9/2/24 7:52 AM, Jovan Vukic wrote:
> The patch adds a new instruction pattern to handle conditional branches 
> with equality checks between shifted arithmetic operands. This pattern 
> optimizes the use of shifted constants (with trailing zeros), making it 
> more efficient.
> 
> For the C code:
> void f5(long long a) {
>    if ((a & 0x2120000) == 0x2000000)
>      g();
> }
> 
> before the patch, the assembly code was:
> f5:
>       li    a5,34734080
>       and   a0,a0,a5
>       li    a5,33554432
>       beq   a0,a5,.L21
>       ret
> 
> and after the patch the assembly is:
> f5:
>       srli  a5,a0,17
>       andi  a5,a5,265
>       li    a4,256
>       beq   a5,a4,.L21
>       ret
> 
> Tested on both RV32 and RV64 with no regressions.
> 
> 2024-09-02  Jovan Vukic  <Jovan.Vukic@rt-rk.com>
> 
> gcc/ChangeLog:
>       PR target/113248
>       * config/riscv/riscv.md 
> (*branch<ANYI:mode>_shiftedarith_equals_shifted): New pattern.
> 
> gcc/testsuite/ChangeLog:
>       PR target/113248
>       * gcc.target/riscv/branch-1.c: Additional tests.
> 
> ---
>   gcc/config/riscv/riscv.md                 | 32 +++++++++++++++++++++++
>   gcc/testsuite/gcc.target/riscv/branch-1.c | 16 +++++++++---
>   2 files changed, 45 insertions(+), 3 deletions(-)
> 
> diff --git a/gcc/config/riscv/riscv.md b/gcc/config/riscv/riscv.md
> index 3289ed2155a..c98a66dbc7c 100644
> --- a/gcc/config/riscv/riscv.md
> +++ b/gcc/config/riscv/riscv.md
> @@ -3126,6 +3126,38 @@
>   }
>   [(set_attr "type" "branch")])
> +(define_insn_and_split "*branch<ANYI:mode>_shiftedarith_equals_shifted"
> +  [(set (pc)
> +     (if_then_else (match_operator 1 "equality_operator"
> +                  [(and:ANYI (match_operand:ANYI 2 "register_operand" "r")
> +                         (match_operand 3 "shifted_const_arith_operand" 
> "i"))
> +                 (match_operand 4 "shifted_const_arith_operand" "i")])
> +      (label_ref (match_operand 0 "" ""))
> +      (pc)))
> +   (clobber (match_scratch:X 5 "=&r"))
> +   (clobber (match_scratch:X 6 "=&r"))]
So match_operator works and I'm guessing you used it due to the its use 
in the existing *branch<ANYI:mode>_shiftedarith_equals_zero pattern.

It's worth noting there is a newer way which is usually slightly simpler 
than a match_operator.  Specifically code iterators.  After defining the 
iterator, you can use it in a pattern just like a simple RTL code.  So 
as an example:

> (define_insn "*<optab><mode>3" 
>   [(set (match_operand:X                0 "register_operand" "=r,r")
>         (any_or:X (match_operand:X 1 "register_operand" "%r,r")
>                        (match_operand:X 2 "arith_operand"    " r,I")))]
>   ""
>   "<insn>%i2\t%0,%1,%2" 
>   [(set_attr "type" "logical")
>    (set_attr "mode" "<MODE>")])
Note the "any_or" reference.  That's a code iterator that expands to ior 
and xor, trivially allowing the pattern to match both cases.  The <insn> 
and <optab> will map the xor/ior to the right assembly mnemonic and the 
optab name.  The definition of any_or,  as well as the mapping iterators 
are all kept in iterators.md.


I don't think you necessary need to change your patch, I'm just pointing 
out there's a newer way to do this rather than use a match_operator.

--



So from a correctness standpoint, after further review, I'm not as 
concerned about the subreg in the output template.   I'm a little 
concerned that this pattern will generate unrecognized insns.

The pattern uses shifted_const_arith_operand, which is good as it 
validates that the constant, if normalized by shifting away its trailing 
zeros fits in a simm12.

But the normalization you're doing on the two constants is limited by 
the smaller of trailing zero counts.  So operands2 might be 0x8100 which 
requires an 8 bit shift for normalization.  operands3 might be 0x81000 
which requires a 12 bit shift for normalization.  In that case we'll use 
8 as our shift count for normalization, resulting in:

0x8100 >> 8 = 0x81, a valid small operand
0x81000 >> 8 = 0x810, not a valid small operand.


I think that'll generate invalid RTL at split time.

What I think you need to do is in the main predicate (the same place 
you're currently !SMALL_OPERAND (INTVAL (operands[3]))), you'll need to 
check that both operands are SMALL_OPERAND after normalization.

I'd suggest putting that check into a little function rather than trying 
to do it all inline.  I wouldn't be surprised if you could have that 
little function also be used in the C fragment which sets up operands8..10.


But I think you're on a good path.



Jeff

ps.  Assuming I'm right, it would seem like a negative test with 0x8100 
and 0x81000 as the constants would be useful.
Jovan Vukic Sept. 5, 2024, 3:03 p.m. UTC | #2
> It's worth noting there is a newer way which is usually slightly simpler
> than a match_operator. Specifically code iterators.

Thank you for the very detailed feedback. It is not a problem to add code iterators. I would add iterators for "eq" and "ne" in riscv/iterators.md since they don't currently exist:

> (define_code_iterator any_eq [eq ne])

I would also add new <optab> values for "eq" and "ne". I assume it would be best to submit the patch again as version 2 with these changes.

> The pattern uses shifted_const_arith_operand, which is good as it
> validates that the constant, if normalized by shifting away its trailing
> zeros fits in a simm12.
>
> But the normalization you're doing on the two constants is limited by
> the smaller of trailing zero counts.  So operands2 might be 0x8100 which
> requires an 8 bit shift for normalization.  operands3 might be 0x81000
> which requires a 12 bit shift for normalization.  In that case we'll use
> 8 as our shift count for normalization, resulting in:
>
> 0x8100 >> 8 = 0x81, a valid small operand
> 0x81000 >> 8 = 0x810, not a valid small operand.
>
>
> I think that'll generate invalid RTL at split time.
>
> What I think you need to do is in the main predicate (the same place
> you're currently !SMALL_OPERAND (INTVAL (operands[3]))), you'll need to
> check that both operands are SMALL_OPERAND after normalization.

Regarding the second issue, thanks again for the clear explanation. While at first this might seem like a problem, I believe these cases won't actually be a problem.

The comparisons you mentioned, (x & 0x81000) == 0x8100 and (x & 0x8100) == 0x81000, will always evaluate as false, and this pattern will never be used for them (https://godbolt.org/z/Y11EGMb4f).

Even in general, I'm quite sure we will never encounter an operand, after shifting, greater than 2^11 (i.e. not a SMALL_OPERAND). I will provide my reasoning below, but if you find it incorrect, or have any counterexamples, I would be happy to make the requested changes, add the mentioned check and submit that as PATCH v2.

Lets consider the general expression (x & c1) == c2, where t1 and t2 represent the counts of trailing zeros in each constant. There are three cases to consider:
1. When t1 == t2: The pattern works fine, with no edge cases.
2. When t1 > t2: The expression will always evaluate as false, and the pattern won’t even be considered. For example, (x & 0x81000) == 0x8100.
3. When t1 < t2: In this case:
   - c1 must be of the form 0x0KLM00 (where the highest bit of K cannot be set) to meet the shifted_const_arith_operand constraint, ensuring SMALL_OPERAND(0x0KLM) == true (i.e. 0x0KLM < 2^11).
   - To prevent the expression from immediately evaluating as false, c2 must be in the form 0x0PQ<0bxxx0>00, where this value has to have only 0 or 1 in bit positions where c1 has 1 (and 0 elsewhere). Otherwise, (x & c1) == c2 would instantly be false, so this pattern wouldn’t be used. Lets call this "the critical condition".
   - After shifting c1 and c2, we have c1 == 0xKLM and c2 == 0xPQ<0bxxx0>, assuming the LSB of M is set to 1.
   - Due to "the critical condition", c2 == 0xPQ<0bxxx0> cannot have the highest bit of P set to 1. Otherwise, (x & c1) == c2 would immediately evaluate as false, since 0xKLM is guaranteed not to have the highest bit of K set to 1. This guarantees that SMALL_OPERAND(0xPQ<0bxxx0>) will always be true (i.e. c2 < 2^11).
CONFIDENTIALITY: The contents of this e-mail are confidential and intended only for the above addressee(s). If you are not the intended recipient, or the person responsible for delivering it to the intended recipient, copying or delivering it to anyone else or using it in any unauthorized manner is prohibited and may be unlawful. If you receive this e-mail by mistake, please notify the sender and the systems administrator at straymail@rt-rk.com immediately.
diff mbox series

Patch

diff --git a/gcc/config/riscv/riscv.md b/gcc/config/riscv/riscv.md
index 3289ed2155a..c98a66dbc7c 100644
--- a/gcc/config/riscv/riscv.md
+++ b/gcc/config/riscv/riscv.md
@@ -3126,6 +3126,38 @@ 
 }
 [(set_attr "type" "branch")])

+(define_insn_and_split "*branch<ANYI:mode>_shiftedarith_equals_shifted"
+  [(set (pc)
+     (if_then_else (match_operator 1 "equality_operator"
+                  [(and:ANYI (match_operand:ANYI 2 "register_operand" "r")
+                         (match_operand 3 "shifted_const_arith_operand" "i"))
+                 (match_operand 4 "shifted_const_arith_operand" "i")])
+      (label_ref (match_operand 0 "" ""))
+      (pc)))
+   (clobber (match_scratch:X 5 "=&r"))
+   (clobber (match_scratch:X 6 "=&r"))]
+  "!SMALL_OPERAND (INTVAL (operands[3]))
+    && !SMALL_OPERAND (INTVAL (operands[4]))"
+  "#"
+  "&& reload_completed"
+  [(set (match_dup 5) (lshiftrt:X (subreg:X (match_dup 2) 0) (match_dup 8)))
+   (set (match_dup 5) (and:X (match_dup 5) (match_dup 9)))
+   (set (match_dup 6) (match_dup 10))
+   (set (pc) (if_then_else (match_op_dup 1 [(match_dup 5) (match_dup 6)])
+                    (label_ref (match_dup 0)) (pc)))]
+{
+  HOST_WIDE_INT mask1 = INTVAL (operands[3]);
+  HOST_WIDE_INT mask2 = INTVAL (operands[4]);
+  int trailing = (ctz_hwi (mask1) > ctz_hwi (mask2))
+           ? ctz_hwi (mask2)
+           : ctz_hwi (mask1);
+
+  operands[8] = GEN_INT (trailing);
+  operands[9] = GEN_INT (mask1 >> trailing);
+  operands[10] = GEN_INT (mask2 >> trailing);
+}
+[(set_attr "type" "branch")])
+
 (define_insn_and_split "*branch<ANYI:mode>_shiftedmask_equals_zero"
   [(set (pc)
      (if_then_else (match_operator 1 "equality_operator"
diff --git a/gcc/testsuite/gcc.target/riscv/branch-1.c b/gcc/testsuite/gcc.target/riscv/branch-1.c
index b4a3a946379..e09328fe705 100644
--- a/gcc/testsuite/gcc.target/riscv/branch-1.c
+++ b/gcc/testsuite/gcc.target/riscv/branch-1.c
@@ -28,10 +28,20 @@  void f4(long long a)
     g();
 }

+void f5(long long a) {
+  if ((a & 0x2120000) == 0x2000000)
+    g();
+}
+
+void f6(long long a) {
+  if ((a & 0x70000000) == 0x30000000)
+    g();
+}
+
 /* { dg-final { scan-assembler-times "slli\t" 2 } } */
-/* { dg-final { scan-assembler-times "srli\t" 3 } } */
-/* { dg-final { scan-assembler-times "andi\t" 1 } } */
-/* { dg-final { scan-assembler-times "\tli\t" 1 } } */
+/* { dg-final { scan-assembler-times "srli\t" 5 } } */
+/* { dg-final { scan-assembler-times "andi\t" 3 } } */
+/* { dg-final { scan-assembler-times "\tli\t" 3 } } */
 /* { dg-final { scan-assembler-not "addi\t" } } */
 /* { dg-final { scan-assembler-not "and\t" } } */