diff mbox series

aarch64: Restore SVE WHILE costing

Message ID mptled8lszh.fsf@arm.com
State New
Headers show
Series aarch64: Restore SVE WHILE costing | expand

Commit Message

Richard Sandiford Sept. 14, 2023, 2:38 p.m. UTC
AArch64 previously costed WHILELO instructions on the first call
to add_stmt_cost.  This was because, at the time, only add_stmt_cost
had access to the loop_vec_info.

However, after the AVX512 changes, we only calculate the masks later,
so the masks were still empty at that point and the WHILELO costs were
no longer being recorded.  This patch moves the WHILELO costing to
finish_cost, which is in any case a more logical place for it to be.
It also means that we can check the final decision about whether to
use predicated loops.

Tested on aarch64-linux-gnu & applied.

Richard


gcc/
	* config/aarch64/aarch64.cc (aarch64_vector_costs::analyze_loop_vinfo):
	Move WHILELO handling to...
	(aarch64_vector_costs::finish_cost): ...here.  Check whether the
	vectorizer has decided to use a predicated loop.

gcc/testsuite/
	* gcc.target/aarch64/sve/cost_model_15.c: New test.
---
 gcc/config/aarch64/aarch64.cc                 | 36 ++++++++++---------
 .../gcc.target/aarch64/sve/cost_model_15.c    | 13 +++++++
 2 files changed, 32 insertions(+), 17 deletions(-)
 create mode 100644 gcc/testsuite/gcc.target/aarch64/sve/cost_model_15.c
diff mbox series

Patch

diff --git a/gcc/config/aarch64/aarch64.cc b/gcc/config/aarch64/aarch64.cc
index 3739a44bfd9..0962fc4f56e 100644
--- a/gcc/config/aarch64/aarch64.cc
+++ b/gcc/config/aarch64/aarch64.cc
@@ -16310,22 +16310,6 @@  aarch64_vector_costs::analyze_loop_vinfo (loop_vec_info loop_vinfo)
   /* Detect whether we're vectorizing for SVE and should apply the unrolling
      heuristic described above m_unrolled_advsimd_niters.  */
   record_potential_advsimd_unrolling (loop_vinfo);
-
-  /* Record the issue information for any SVE WHILE instructions that the
-     loop needs.  */
-  if (!m_ops.is_empty () && !LOOP_VINFO_MASKS (loop_vinfo).is_empty ())
-    {
-      unsigned int num_masks = 0;
-      rgroup_controls *rgm;
-      unsigned int num_vectors_m1;
-      FOR_EACH_VEC_ELT (LOOP_VINFO_MASKS (loop_vinfo).rgc_vec,
-			num_vectors_m1, rgm)
-	if (rgm->type)
-	  num_masks += num_vectors_m1 + 1;
-      for (auto &ops : m_ops)
-	if (auto *issue = ops.sve_issue_info ())
-	  ops.pred_ops += num_masks * issue->while_pred_ops;
-    }
 }
 
 /* Implement targetm.vectorize.builtin_vectorization_cost.  */
@@ -17507,9 +17491,27 @@  adjust_body_cost (loop_vec_info loop_vinfo,
 void
 aarch64_vector_costs::finish_cost (const vector_costs *uncast_scalar_costs)
 {
+  /* Record the issue information for any SVE WHILE instructions that the
+     loop needs.  */
+  loop_vec_info loop_vinfo = dyn_cast<loop_vec_info> (m_vinfo);
+  if (!m_ops.is_empty ()
+      && loop_vinfo
+      && LOOP_VINFO_FULLY_MASKED_P (loop_vinfo))
+    {
+      unsigned int num_masks = 0;
+      rgroup_controls *rgm;
+      unsigned int num_vectors_m1;
+      FOR_EACH_VEC_ELT (LOOP_VINFO_MASKS (loop_vinfo).rgc_vec,
+			num_vectors_m1, rgm)
+	if (rgm->type)
+	  num_masks += num_vectors_m1 + 1;
+      for (auto &ops : m_ops)
+	if (auto *issue = ops.sve_issue_info ())
+	  ops.pred_ops += num_masks * issue->while_pred_ops;
+    }
+
   auto *scalar_costs
     = static_cast<const aarch64_vector_costs *> (uncast_scalar_costs);
-  loop_vec_info loop_vinfo = dyn_cast<loop_vec_info> (m_vinfo);
   if (loop_vinfo
       && m_vec_flags
       && aarch64_use_new_vector_costs_p ())
diff --git a/gcc/testsuite/gcc.target/aarch64/sve/cost_model_15.c b/gcc/testsuite/gcc.target/aarch64/sve/cost_model_15.c
new file mode 100644
index 00000000000..b9e6306bb59
--- /dev/null
+++ b/gcc/testsuite/gcc.target/aarch64/sve/cost_model_15.c
@@ -0,0 +1,13 @@ 
+/* { dg-options "-Ofast -mtune=neoverse-v1" } */
+
+double f(double *restrict x, double *restrict y, int *restrict z)
+{
+  double res = 0.0;
+  for (int i = 0; i < 100; ++i)
+    res += x[i] * y[z[i]];
+  return res;
+}
+
+/* { dg-final { scan-assembler-times {\tld1sw\tz[0-9]+\.d,} 1 } } */
+/* { dg-final { scan-assembler-times {\tld1d\tz[0-9]+\.d,} 2 } } */
+/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.d,} 1 } } */