diff mbox series

[3/4] Add double reduction support for SLP vectorization

Message ID 20240605062333.6302838F14FE@sourceware.org
State New
Headers show
Series [1/4] Relax COND_EXPR reduction vectorization SLP restriction | expand

Commit Message

Richard Biener June 5, 2024, 6:23 a.m. UTC
The following makes double reduction vectorization work when
using (single-lane) SLP vectorization.

	* tree-vect-loop.cc (vect_analyze_scalar_cycles_1): Queue
	double reductions in LOOP_VINFO_REDUCTIONS.
	(vect_create_epilog_for_reduction): Remove asserts disabling
	SLP for double reductions.
	(vectorizable_reduction): Analyze SLP double reductions
	only once and start off the correct places.
	* tree-vect-slp.cc (vect_get_and_check_slp_defs): Allow
	vect_double_reduction_def.
	(vect_build_slp_tree_2): Fix condition for the ignored
	reduction initial values.
	* tree-vect-stmts.cc (vect_analyze_stmt): Allow
	vect_double_reduction_def.
---
 gcc/tree-vect-loop.cc  | 35 +++++++++++++++++++++++++----------
 gcc/tree-vect-slp.cc   |  3 ++-
 gcc/tree-vect-stmts.cc |  4 ++++
 3 files changed, 31 insertions(+), 11 deletions(-)
diff mbox series

Patch

diff --git a/gcc/tree-vect-loop.cc b/gcc/tree-vect-loop.cc
index ccd6acef5c5..b9e8e9b5559 100644
--- a/gcc/tree-vect-loop.cc
+++ b/gcc/tree-vect-loop.cc
@@ -685,6 +685,8 @@  vect_analyze_scalar_cycles_1 (loop_vec_info loop_vinfo, class loop *loop,
 
               STMT_VINFO_DEF_TYPE (stmt_vinfo) = vect_double_reduction_def;
 	      STMT_VINFO_DEF_TYPE (reduc_stmt_info) = vect_double_reduction_def;
+	      /* Make it accessible for SLP vectorization.  */
+	      LOOP_VINFO_REDUCTIONS (loop_vinfo).safe_push (reduc_stmt_info);
             }
           else
             {
@@ -5975,7 +5977,6 @@  vect_create_epilog_for_reduction (loop_vec_info loop_vinfo,
   stmt_vec_info rdef_info = stmt_info;
   if (STMT_VINFO_DEF_TYPE (stmt_info) == vect_double_reduction_def)
     {
-      gcc_assert (!slp_node);
       double_reduc = true;
       stmt_info = loop_vinfo->lookup_def (gimple_phi_arg_def
 					    (stmt_info->stmt, 0));
@@ -6020,7 +6021,7 @@  vect_create_epilog_for_reduction (loop_vec_info loop_vinfo,
     {
       outer_loop = loop;
       loop = loop->inner;
-      gcc_assert (!slp_node && double_reduc);
+      gcc_assert (double_reduc);
     }
 
   vectype = STMT_VINFO_REDUC_VECTYPE (reduc_info);
@@ -6035,7 +6036,7 @@  vect_create_epilog_for_reduction (loop_vec_info loop_vinfo,
 	 for induc_val, use initial_def.  */
       if (STMT_VINFO_REDUC_TYPE (reduc_info) == INTEGER_INDUC_COND_REDUCTION)
 	induc_val = STMT_VINFO_VEC_INDUC_COND_INITIAL_VAL (reduc_info);
-      /* ???  Coverage for double_reduc and 'else' isn't clear.  */
+      /* ???  Coverage for 'else' isn't clear.  */
     }
   else
     {
@@ -7605,15 +7606,16 @@  vectorizable_reduction (loop_vec_info loop_vinfo,
       STMT_VINFO_TYPE (stmt_info) = reduc_vec_info_type;
       return true;
     }
-  if (slp_node)
-    {
-      slp_node_instance->reduc_phis = slp_node;
-      /* ???  We're leaving slp_node to point to the PHIs, we only
-	 need it to get at the number of vector stmts which wasn't
-	 yet initialized for the instance root.  */
-    }
   if (STMT_VINFO_DEF_TYPE (stmt_info) == vect_double_reduction_def)
     {
+      if (gimple_bb (stmt_info->stmt) != loop->header)
+	{
+	  /* For SLP we arrive here for both the inner loop LC PHI and
+	     the outer loop PHI.  The latter is what we want to analyze
+	     the reduction with.  */
+	  gcc_assert (slp_node);
+	  return true;
+	}
       use_operand_p use_p;
       gimple *use_stmt;
       bool res = single_imm_use (gimple_phi_result (stmt_info->stmt),
@@ -7622,6 +7624,14 @@  vectorizable_reduction (loop_vec_info loop_vinfo,
       phi_info = loop_vinfo->lookup_stmt (use_stmt);
     }
 
+  if (slp_node)
+    {
+      slp_node_instance->reduc_phis = slp_node;
+      /* ???  We're leaving slp_node to point to the PHIs, we only
+	 need it to get at the number of vector stmts which wasn't
+	 yet initialized for the instance root.  */
+    }
+
   /* PHIs should not participate in patterns.  */
   gcc_assert (!STMT_VINFO_RELATED_STMT (phi_info));
   gphi *reduc_def_phi = as_a <gphi *> (phi_info->stmt);
@@ -7637,6 +7647,11 @@  vectorizable_reduction (loop_vec_info loop_vinfo,
   bool only_slp_reduc_chain = true;
   stmt_info = NULL;
   slp_tree slp_for_stmt_info = slp_node ? slp_node_instance->root : NULL;
+  /* For double-reductions we start SLP analysis at the inner loop LC PHI
+     which is the def of the outer loop live stmt.  */
+  if (STMT_VINFO_DEF_TYPE (reduc_info) == vect_double_reduction_def
+      && slp_node)
+    slp_for_stmt_info = SLP_TREE_CHILDREN (slp_for_stmt_info)[0];
   while (reduc_def != PHI_RESULT (reduc_def_phi))
     {
       stmt_vec_info def = loop_vinfo->lookup_def (reduc_def);
diff --git a/gcc/tree-vect-slp.cc b/gcc/tree-vect-slp.cc
index ba1190c7155..7e3d0107b4e 100644
--- a/gcc/tree-vect-slp.cc
+++ b/gcc/tree-vect-slp.cc
@@ -778,6 +778,7 @@  vect_get_and_check_slp_defs (vec_info *vinfo, unsigned char swap,
 	    case vect_constant_def:
 	    case vect_internal_def:
 	    case vect_reduction_def:
+	    case vect_double_reduction_def:
 	    case vect_induction_def:
 	    case vect_nested_cycle:
 	    case vect_first_order_recurrence:
@@ -1906,7 +1907,7 @@  vect_build_slp_tree_2 (vec_info *vinfo, slp_tree node,
 	    class loop *loop = LOOP_VINFO_LOOP (loop_vinfo);
 	    /* Reduction initial values are not explicitely represented.  */
 	    if (def_type != vect_first_order_recurrence
-		&& !nested_in_vect_loop_p (loop, stmt_info))
+		&& gimple_bb (stmt_info->stmt) == loop->header)
 	      skip_args[loop_preheader_edge (loop)->dest_idx] = true;
 	    /* Reduction chain backedge defs are filled manually.
 	       ???  Need a better way to identify a SLP reduction chain PHI.
diff --git a/gcc/tree-vect-stmts.cc b/gcc/tree-vect-stmts.cc
index c82381e799e..5098b7fab6a 100644
--- a/gcc/tree-vect-stmts.cc
+++ b/gcc/tree-vect-stmts.cc
@@ -13260,6 +13260,10 @@  vect_analyze_stmt (vec_info *vinfo,
 			 || relevance == vect_used_only_live));
          break;
 
+      case vect_double_reduction_def:
+	gcc_assert (!bb_vinfo && node);
+	break;
+
       case vect_induction_def:
       case vect_first_order_recurrence:
 	gcc_assert (!bb_vinfo);