[2/2,v2] tree-optimization/116575 - SLP masked load-lanes discovery

Message ID 20241024142025.2D1D3136F5@imap1.dmz-prg2.suse.org
State New
Series [1/2,v2] Relax vect_check_scalar_mask check

Commit Message

Richard Biener Oct. 24, 2024, 2:20 p.m. UTC
The following implements masked load-lane discovery for SLP.  The
challenge here is that a masked load has a full-width mask with
group-size number of elements; when this becomes a masked load-lanes
instruction, one mask element gates all group members.  We already
have some discovery hints in place, namely STMT_VINFO_SLP_VECT_ONLY
to guard non-uniform masks, but we need to choose a way for SLP
discovery to handle possible masked load-lanes SLP trees.
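
For illustration, the kind of loop in question looks roughly like the
following (a sketch in the spirit of the mask_struct_load*.c
testcases, not one of them verbatim):

  /* The two loads from src form a group of size two.  As IFN_MASK_LOADs
     each lane carries its own mask element derived from cond[i]; as a
     masked load-lanes (e.g. SVE LD2 under a predicate) a single mask
     element gates both group members.  */
  void
  f (int *restrict dst, int *restrict src, int *restrict cond, int n)
  {
    for (int i = 0; i < n; ++i)
      if (cond[i])
        dst[i] = src[2 * i] + src[2 * i + 1];
  }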

I have this time chosen to perform load-lanes discovery after
permute optimization, where we conveniently have the graph with
predecessor edges built.  This is because, unlike non-masked loads,
masked loads with a load_permutation are never produced by SLP
discovery (load permutation handling doesn't handle un-permuting the
mask), and thus the load-permutation lowering which handles
non-masked load-lanes discovery doesn't trigger.

With this, SLP discovery for a possible masked load-lanes, i.e. a
masked load with a uniform mask, produces a splat of a single-lane
sub-graph as the mask SLP operand.  This is a representation that
shouldn't pessimize the plain masked load case and allows the masked
load-lanes transform to simply elide this splat.
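
Concretely, for a group of size two the mask operand is then
represented roughly as follows (a sketch of the SLP tree shape, not
actual dump output):

  masked load node (2 lanes, IFN_MASK_LOAD group)
    child 0: VEC_PERM_EXPR node, 2 lanes,
             lane permutation { (0, 0), (0, 0) }   <- the splat
      child 0: single-lane node defining the mask

When the load is later turned into a load-lanes, the splat VEC_PERM
node is elided and the single-lane mask node becomes the load's mask
operand directly.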

This fixes the aarch64-sve.exp mask_struct_load*.c testcases with
--param vect-force-slp=1.
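
For reference, something along these lines reproduces the situation
outside the testsuite (illustrative invocation only; the exact
options and testcase name come from the testsuite and may differ):

  gcc -O2 -ftree-vectorize -march=armv8.2-a+sve \
      --param vect-force-slp=1 -S mask_struct_load_1.c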

Re-bootstrap & regtest running on x86_64-unknown-linux-gnu; the
observed CI FAILs are gone.

	PR tree-optimization/116575
	* tree-vect-slp.cc (vect_get_and_check_slp_defs): Handle
	gaps, aka NULL scalar stmt.
	(vect_build_slp_tree_2): Allow gaps in the middle of a
	grouped mask load.  When the mask of a grouped mask load
	is uniform do single-lane discovery for the mask and
	insert a splat VEC_PERM_EXPR node.
	(vect_optimize_slp_pass::decide_masked_load_lanes): New
	function.
	(vect_optimize_slp_pass::run): Call it.
---
 gcc/tree-vect-slp.cc | 141 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 138 insertions(+), 3 deletions(-)

Patch

diff --git a/gcc/tree-vect-slp.cc b/gcc/tree-vect-slp.cc
index 53f5400a961..b192328e3eb 100644
--- a/gcc/tree-vect-slp.cc
+++ b/gcc/tree-vect-slp.cc
@@ -641,6 +641,16 @@  vect_get_and_check_slp_defs (vec_info *vinfo, unsigned char swap,
   unsigned int commutative_op = -1U;
   bool first = stmt_num == 0;
 
+  if (!stmt_info)
+    {
+      for (auto oi : *oprnds_info)
+	{
+	  oi->def_stmts.quick_push (NULL);
+	  oi->ops.quick_push (NULL_TREE);
+	}
+      return 0;
+    }
+
   if (!is_a<gcall *> (stmt_info->stmt)
       && !is_a<gassign *> (stmt_info->stmt)
       && !is_a<gphi *> (stmt_info->stmt))
@@ -2029,9 +2039,11 @@  vect_build_slp_tree_2 (vec_info *vinfo, slp_tree node,
 		    has_gaps = true;
 	      /* We cannot handle permuted masked loads directly, see
 		 PR114375.  We cannot handle strided masked loads or masked
-		 loads with gaps.  */
+		 loads with gaps unless the mask is uniform.  */
 	      if ((STMT_VINFO_GROUPED_ACCESS (stmt_info)
-		   && (DR_GROUP_GAP (first_stmt_info) != 0 || has_gaps))
+		   && (DR_GROUP_GAP (first_stmt_info) != 0
+		       || (has_gaps
+			   && STMT_VINFO_SLP_VECT_ONLY (first_stmt_info))))
 		  || STMT_VINFO_STRIDED_P (stmt_info))
 		{
 		  load_permutation.release ();
@@ -2054,7 +2066,12 @@  vect_build_slp_tree_2 (vec_info *vinfo, slp_tree node,
 		  unsigned i = 0;
 		  for (stmt_vec_info si = first_stmt_info;
 		       si; si = DR_GROUP_NEXT_ELEMENT (si))
-		    stmts2[i++] = si;
+		    {
+		      if (si != first_stmt_info)
+			for (unsigned k = 1; k < DR_GROUP_GAP (si); ++k)
+			  stmts2[i++] = NULL;
+		      stmts2[i++] = si;
+		    }
 		  bool *matches2 = XALLOCAVEC (bool, dr_group_size);
 		  slp_tree unperm_load
 		    = vect_build_slp_tree (vinfo, stmts2, dr_group_size,
@@ -2683,6 +2700,46 @@  out:
 	  continue;
 	}
 
+      /* When we have a masked load with uniform mask discover this
+	 as a single-lane mask with a splat permute.  This way we can
+	 recognize this as a masked load-lane by stripping the splat.  */
+      if (is_a <gcall *> (STMT_VINFO_STMT (stmt_info))
+	  && gimple_call_internal_p (STMT_VINFO_STMT (stmt_info),
+				     IFN_MASK_LOAD)
+	  && STMT_VINFO_GROUPED_ACCESS (stmt_info)
+	  && ! STMT_VINFO_SLP_VECT_ONLY (DR_GROUP_FIRST_ELEMENT (stmt_info)))
+	{
+	  vec<stmt_vec_info> def_stmts2;
+	  def_stmts2.create (1);
+	  def_stmts2.quick_push (oprnd_info->def_stmts[0]);
+	  child = vect_build_slp_tree (vinfo, def_stmts2, 1,
+				       &this_max_nunits,
+				       matches, limit,
+				       &this_tree_size, bst_map);
+	  if (child)
+	    {
+	      slp_tree pnode = vect_create_new_slp_node (1, VEC_PERM_EXPR);
+	      SLP_TREE_VECTYPE (pnode) = SLP_TREE_VECTYPE (child);
+	      SLP_TREE_LANES (pnode) = group_size;
+	      SLP_TREE_SCALAR_STMTS (pnode).create (group_size);
+	      SLP_TREE_LANE_PERMUTATION (pnode).create (group_size);
+	      for (unsigned k = 0; k < group_size; ++k)
+		{
+		  SLP_TREE_SCALAR_STMTS (pnode)
+		    .quick_push (oprnd_info->def_stmts[0]);
+		  SLP_TREE_LANE_PERMUTATION (pnode)
+		    .quick_push (std::make_pair (0u, 0u));
+		}
+	      SLP_TREE_CHILDREN (pnode).quick_push (child);
+	      pnode->max_nunits = child->max_nunits;
+	      children.safe_push (pnode);
+	      oprnd_info->def_stmts = vNULL;
+	      continue;
+	    }
+	  else
+	    def_stmts2.release ();
+	}
+
       if ((child = vect_build_slp_tree (vinfo, oprnd_info->def_stmts,
 					group_size, &this_max_nunits,
 					matches, limit,
@@ -5462,6 +5519,9 @@  private:
   /* Clean-up.  */
   void remove_redundant_permutations ();
 
+  /* Masked load lanes discovery.  */
+  void decide_masked_load_lanes ();
+
   void dump ();
 
   vec_info *m_vinfo;
@@ -7090,6 +7150,80 @@  vect_optimize_slp_pass::dump ()
     }
 }
 
+/* Masked load lanes discovery.  */
+
+void
+vect_optimize_slp_pass::decide_masked_load_lanes ()
+{
+  for (auto v : m_vertices)
+    {
+      slp_tree node = v.node;
+      if (SLP_TREE_DEF_TYPE (node) != vect_internal_def
+	  || SLP_TREE_CODE (node) == VEC_PERM_EXPR)
+	continue;
+      stmt_vec_info stmt_info = SLP_TREE_REPRESENTATIVE (node);
+      if (! STMT_VINFO_GROUPED_ACCESS (stmt_info)
+	  /* The mask has to be uniform.  */
+	  || STMT_VINFO_SLP_VECT_ONLY (stmt_info)
+	  || ! is_a <gcall *> (STMT_VINFO_STMT (stmt_info))
+	  || ! gimple_call_internal_p (STMT_VINFO_STMT (stmt_info),
+				       IFN_MASK_LOAD))
+	continue;
+      stmt_info = DR_GROUP_FIRST_ELEMENT (stmt_info);
+      if (STMT_VINFO_STRIDED_P (stmt_info)
+	  || compare_step_with_zero (m_vinfo, stmt_info) <= 0
+	  || vect_load_lanes_supported (SLP_TREE_VECTYPE (node),
+					DR_GROUP_SIZE (stmt_info),
+					true) == IFN_LAST)
+	continue;
+
+      /* Uniform masks need to be suitably represented.  */
+      slp_tree mask = SLP_TREE_CHILDREN (node)[0];
+      if (SLP_TREE_CODE (mask) != VEC_PERM_EXPR
+	  || SLP_TREE_CHILDREN (mask).length () != 1)
+	continue;
+      bool match = true;
+      for (auto perm : SLP_TREE_LANE_PERMUTATION (mask))
+	if (perm.first != 0 || perm.second != 0)
+	  {
+	    match = false;
+	    break;
+	  }
+      if (!match)
+	continue;
+
+      /* Now see if the consumer side matches.  */
+      for (graph_edge *pred = m_slpg->vertices[node->vertex].pred;
+	   pred; pred = pred->pred_next)
+	{
+	  slp_tree pred_node = m_vertices[pred->src].node;
+	  /* All consumers should be a permute with a single outgoing lane.  */
+	  if (SLP_TREE_CODE (pred_node) != VEC_PERM_EXPR
+	      || SLP_TREE_LANES (pred_node) != 1)
+	    {
+	      match = false;
+	      break;
+	    }
+	  gcc_assert (SLP_TREE_CHILDREN (pred_node).length () == 1);
+	}
+      if (!match)
+	continue;
+      /* Now we can mark the nodes as to use load lanes.  */
+      node->ldst_lanes = true;
+      for (graph_edge *pred = m_slpg->vertices[node->vertex].pred;
+	   pred; pred = pred->pred_next)
+	m_vertices[pred->src].node->ldst_lanes = true;
+      /* The catch is we have to massage the mask.  We have arranged
+	 analyzed uniform masks to be represented by a splat VEC_PERM
+	 which we can now simply elide as we cannot easily re-do SLP
+	 discovery here.  */
+      slp_tree new_mask = SLP_TREE_CHILDREN (mask)[0];
+      SLP_TREE_REF_COUNT (new_mask)++;
+      SLP_TREE_CHILDREN (node)[0] = new_mask;
+      vect_free_slp_tree (mask);
+    }
+}
+
 /* Main entry point for the SLP graph optimization pass.  */
 
 void
@@ -7110,6 +7244,7 @@  vect_optimize_slp_pass::run ()
     }
   else
     remove_redundant_permutations ();
+  decide_masked_load_lanes ();
   free_graph (m_slpg);
 }