diff mbox series

[1/2] Enhance cse_insn to handle all-zeros and all-ones for vector mode.

Message ID 20240826073318.2521204-1-hongtao.liu@intel.com
State New
Headers show
Series [1/2] Enhance cse_insn to handle all-zeros and all-ones for vector mode. | expand

Commit Message

liuhongt Aug. 26, 2024, 7:33 a.m. UTC
Also try to handle redundant broadcasts when there's already a
broadcast to a bigger mode with exactly the same component value.
For broadcast, component mode needs to be the same.
For all-zeros/ones, only need to check the bigger mode.

Bootstrapped and regtested on x86_64-pc-linux-gnu{-m32,} and aarch64-linux-gnu{-m32,}.
OK for trunk?

gcc/ChangeLog:

	PR rtl-optimization/92080
	* cse.cc (cse_insn): Handle all-ones/all-zeros, and vec_dup
	with variables.
---
 gcc/cse.cc | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 79 insertions(+)

Comments

Richard Biener Aug. 26, 2024, 9:43 a.m. UTC | #1
On Mon, Aug 26, 2024 at 9:34 AM liuhongt <hongtao.liu@intel.com> wrote:
>
> Also try to handle redundant broadcasts when there's already a
> broadcast to a bigger mode with exactly the same component value.
> For broadcast, component mode needs to be the same.
> For all-zeros/ones, only need to check the bigger mode.
>
> Bootstrapped and regtested on x86_64-pc-linux-gnu{-m32,} and aarch64-linux-gnu{-m32,}.
> OK for trunk?
>
> gcc/ChangeLog:
>
>         PR rtl-optimization/92080
>         * cse.cc (cse_insn): Handle all-ones/all-zeros, and vec_dup
>         with variables.
> ---
>  gcc/cse.cc | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
>  1 file changed, 79 insertions(+)
>
> diff --git a/gcc/cse.cc b/gcc/cse.cc
> index 65794ac5f2c..baf90910b94 100644
> --- a/gcc/cse.cc
> +++ b/gcc/cse.cc
> @@ -4870,6 +4870,50 @@ cse_insn (rtx_insn *insn)
>             }
>         }
>
> +      /* Try to handle special const_vector with elt 0 or -1.
> +        They can be represented with different modes, and can be cse.  */
> +      if (src_const && src_related == 0 && CONST_VECTOR_P (src_const)
> +         && (src_const == CONST0_RTX (mode)
> +             || src_const == CONSTM1_RTX (mode))
> +         && GET_MODE_CLASS (mode) == MODE_VECTOR_INT)
> +       {
> +         machine_mode mode_iter;
> +
> +         for (int l = 0; l != 2; l++)
> +           {
> +             FOR_EACH_MODE_IN_CLASS (mode_iter, MODE_VECTOR_INT)
> +               {
> +                 if (maybe_lt (GET_MODE_SIZE (mode_iter),
> +                               GET_MODE_SIZE (mode)))
> +                   continue;
> +
> +                 rtx src_const_iter = (src_const == CONST0_RTX (mode)
> +                                       ? CONST0_RTX (mode_iter)
> +                                       : CONSTM1_RTX (mode_iter));
> +
> +                 struct table_elt *const_elt
> +                   = lookup (src_const_iter, HASH (src_const_iter, mode_iter),
> +                             mode_iter);
> +
> +                 if (const_elt == 0)
> +                   continue;
> +
> +                 for (const_elt = const_elt->first_same_value;
> +                      const_elt; const_elt = const_elt->next_same_value)
> +                   if (REG_P (const_elt->exp))
> +                     {
> +                       src_related = gen_lowpart (mode, const_elt->exp);
> +                       break;
> +                     }
> +
> +                 if (src_related != 0)
> +                   break;
> +               }
> +             if (src_related != 0)
> +               break;
> +           }
> +       }
> +
>        /* See if we have a CONST_INT that is already in a register in a
>          wider mode.  */
>
> @@ -5041,6 +5085,41 @@ cse_insn (rtx_insn *insn)
>             }
>         }
>
> +      /* Try to find something like (vec_dup:v16si (reg:c))
> +            for (vec_dup:v8si (reg:c)).  */
> +      if (src_related == 0
> +         && VECTOR_MODE_P (mode)
> +         && GET_CODE (src) == VEC_DUPLICATE)
> +       {
> +         poly_uint64 nunits = GET_MODE_NUNITS (GET_MODE (src)) * 2;
> +         rtx inner_elt = XEXP (src, 0);
> +         machine_mode result_mode;
> +         struct table_elt *src_related_elt = NULL;;
> +         while (related_vector_mode (mode, GET_MODE_INNER (mode),
> +                                     nunits).exists (&result_mode))
> +           {
> +             rtx vec_dup = gen_rtx_VEC_DUPLICATE (result_mode, inner_elt);
> +             struct table_elt* tmp = lookup (vec_dup, HASH (vec_dup, result_mode),
> +                                             result_mode);
> +             if (tmp)
> +               src_related_elt = tmp;

You are possibly overwriting src_related_elt - I'd suggest to either break
here or do the loop below for each found elt?

> +             nunits *= 2;
> +           }
> +
> +         if (src_related_elt)
> +           {
> +             for (src_related_elt = src_related_elt->first_same_value;
> +                  src_related_elt;
> +                  src_related_elt = src_related_elt->next_same_value)
> +               if (REG_P (src_related_elt->exp))
> +                 {
> +                   src_related = gen_lowpart (mode, src_related_elt->exp);

Do we know that will always succeed?

> +                   break;
> +                 }
> +           }
> +       }

So on the GIMPLE side we are trying to handle such cases by maintaining
only a single element in the hashtables, thus hash and compare them
the same - them in this case (vec_dup:M (reg:c)) and (vec_dup:N (reg:c)),
leaving it up to the consumer to reject or pun mismatches.

For constants that would hold even more - note CSEing vs. duplicating
constants might not be universally good.

Richard.

> +
>        if (src == src_folded)
>         src_folded = 0;
>
> --
> 2.31.1
>
diff mbox series

Patch

diff --git a/gcc/cse.cc b/gcc/cse.cc
index 65794ac5f2c..baf90910b94 100644
--- a/gcc/cse.cc
+++ b/gcc/cse.cc
@@ -4870,6 +4870,50 @@  cse_insn (rtx_insn *insn)
 	    }
 	}
 
+      /* Try to handle special const_vector with elt 0 or -1.
+	 They can be represented with different modes, and can be cse.  */
+      if (src_const && src_related == 0 && CONST_VECTOR_P (src_const)
+	  && (src_const == CONST0_RTX (mode)
+	      || src_const == CONSTM1_RTX (mode))
+	  && GET_MODE_CLASS (mode) == MODE_VECTOR_INT)
+	{
+	  machine_mode mode_iter;
+
+	  for (int l = 0; l != 2; l++)
+	    {
+	      FOR_EACH_MODE_IN_CLASS (mode_iter, MODE_VECTOR_INT)
+		{
+		  if (maybe_lt (GET_MODE_SIZE (mode_iter),
+				GET_MODE_SIZE (mode)))
+		    continue;
+
+		  rtx src_const_iter = (src_const == CONST0_RTX (mode)
+					? CONST0_RTX (mode_iter)
+					: CONSTM1_RTX (mode_iter));
+
+		  struct table_elt *const_elt
+		    = lookup (src_const_iter, HASH (src_const_iter, mode_iter),
+			      mode_iter);
+
+		  if (const_elt == 0)
+		    continue;
+
+		  for (const_elt = const_elt->first_same_value;
+		       const_elt; const_elt = const_elt->next_same_value)
+		    if (REG_P (const_elt->exp))
+		      {
+			src_related = gen_lowpart (mode, const_elt->exp);
+			break;
+		      }
+
+		  if (src_related != 0)
+		    break;
+		}
+	      if (src_related != 0)
+		break;
+	    }
+	}
+
       /* See if we have a CONST_INT that is already in a register in a
 	 wider mode.  */
 
@@ -5041,6 +5085,41 @@  cse_insn (rtx_insn *insn)
 	    }
 	}
 
+      /* Try to find something like (vec_dup:v16si (reg:c))
+	     for (vec_dup:v8si (reg:c)).  */
+      if (src_related == 0
+	  && VECTOR_MODE_P (mode)
+	  && GET_CODE (src) == VEC_DUPLICATE)
+	{
+	  poly_uint64 nunits = GET_MODE_NUNITS (GET_MODE (src)) * 2;
+	  rtx inner_elt = XEXP (src, 0);
+	  machine_mode result_mode;
+	  struct table_elt *src_related_elt = NULL;;
+	  while (related_vector_mode (mode, GET_MODE_INNER (mode),
+				      nunits).exists (&result_mode))
+	    {
+	      rtx vec_dup = gen_rtx_VEC_DUPLICATE (result_mode, inner_elt);
+	      struct table_elt* tmp = lookup (vec_dup, HASH (vec_dup, result_mode),
+					      result_mode);
+	      if (tmp)
+		src_related_elt = tmp;
+	      nunits *= 2;
+	    }
+
+	  if (src_related_elt)
+	    {
+	      for (src_related_elt = src_related_elt->first_same_value;
+		   src_related_elt;
+		   src_related_elt = src_related_elt->next_same_value)
+		if (REG_P (src_related_elt->exp))
+		  {
+		    src_related = gen_lowpart (mode, src_related_elt->exp);
+		    break;
+		  }
+	    }
+	}
+
+
       if (src == src_folded)
 	src_folded = 0;