diff mbox series

[04/10] arm: Fix arm backend-use of (u|s|us)dot_prod patterns.

Message ID 20240710140602.1707875-5-victor.donascimento@arm.com
State New
Headers show
Series Make `dot_prod' a convert-type optab | expand

Commit Message

Victor Do Nascimento July 10, 2024, 2:05 p.m. UTC
gcc/ChangeLog:

	* config/arm/arm-builtins.cc (enum arm_builtins): Add new
	ARM_BUILTIN_* enum values: SDOTV8QI, SDOTV16QI, UDOTV8QI,
	UDOTV16QI, USDOTV8QI, USDOTV16QI.
	(arm_init_dotprod_builtins): New.
	(arm_init_builtins): Add call to `arm_init_dotprod_builtins'.
	(arm_general_gimple_fold_builtin): New.
	* config/arm/arm-protos.h (arm_general_gimple_fold_builtin):
	New prototype.
	* config/arm/arm.cc (arm_gimple_fold_builtin): Add call to
	`arm_general_gimple_fold_builtin'.
	* config/arm/neon.md (<sup>dot_prod<vsi2qi>): Deleted.
	(<sup>dot_prod<mode><vsi2qi>): New.
	(neon_usdot<vsi2qi>): Deleted.
	(neon_usdot<mode><vsi2qi>): New.
---
 gcc/config/arm/arm-builtins.cc       | 95 ++++++++++++++++++++++++++++
 gcc/config/arm/arm-protos.h          |  3 +
 gcc/config/arm/arm.cc                |  1 +
 gcc/config/arm/arm_neon_builtins.def |  3 -
 gcc/config/arm/neon.md               |  4 +-
 5 files changed, 101 insertions(+), 5 deletions(-)
diff mbox series

Patch

diff --git a/gcc/config/arm/arm-builtins.cc b/gcc/config/arm/arm-builtins.cc
index c9d50bf8fbb..b23b6caa063 100644
--- a/gcc/config/arm/arm-builtins.cc
+++ b/gcc/config/arm/arm-builtins.cc
@@ -45,6 +45,8 @@ 
 #include "arm-builtins.h"
 #include "stringpool.h"
 #include "attribs.h"
+#include "basic-block.h"
+#include "gimple.h"
 
 #define SIMD_MAX_BUILTIN_ARGS 7
 
@@ -1298,6 +1300,13 @@  enum arm_builtins
 #define VAR1(T, N, X) \
   ARM_BUILTIN_##N,
 
+  ARM_BUILTIN_NEON_SDOTV8QI,
+  ARM_BUILTIN_NEON_SDOTV16QI,
+  ARM_BUILTIN_NEON_UDOTV8QI,
+  ARM_BUILTIN_NEON_UDOTV16QI,
+  ARM_BUILTIN_NEON_USDOTV8QI,
+  ARM_BUILTIN_NEON_USDOTV16QI,
+
   ARM_BUILTIN_ACLE_BASE,
   ARM_BUILTIN_SAT_IMM_CHECK = ARM_BUILTIN_ACLE_BASE,
 
@@ -2648,6 +2657,60 @@  arm_init_fp16_builtins (void)
 					       "__fp16");
 }
 
+static void
+arm_init_dotprod_builtins (void)
+{
+  tree fndecl = NULL;
+  tree ftype = NULL;
+
+  tree uv8qi = arm_simd_builtin_type (V8QImode, qualifier_unsigned);
+  tree sv8qi = arm_simd_builtin_type (V8QImode, qualifier_none);
+  tree uv16qi = arm_simd_builtin_type (V16QImode, qualifier_unsigned);
+  tree sv16qi = arm_simd_builtin_type (V16QImode, qualifier_none);
+  tree uv2si = arm_simd_builtin_type (V2SImode, qualifier_unsigned);
+  tree sv2si = arm_simd_builtin_type (V2SImode, qualifier_none);
+  tree uv4si = arm_simd_builtin_type (V4SImode, qualifier_unsigned);
+  tree sv4si = arm_simd_builtin_type (V4SImode, qualifier_none);
+
+  struct builtin_decls_data
+  {
+    tree out_type_node;
+    tree in_type1_node;
+    tree in_type2_node;
+    const char *builtin_name;
+    int function_code;
+  };
+
+#define NAME(A) "__builtin_neon_" #A
+#define ENUM(B) ARM_BUILTIN_NEON_##B
+
+  builtin_decls_data bdda[] =
+  {
+    { sv2si, sv8qi,  sv8qi,  NAME (sdotv8qi),	    ENUM (SDOTV8QI)   },
+    { uv2si, uv8qi,  uv8qi,  NAME (udotv8qi_uuuu),  ENUM (UDOTV8QI)   },
+    { sv2si, uv8qi,  sv8qi,  NAME (usdotv8qi_ssus), ENUM (USDOTV8QI)  },
+    { sv4si, sv16qi, sv16qi, NAME (sdotv16qi),	    ENUM (SDOTV16QI)  },
+    { uv4si, uv16qi, uv16qi, NAME (udotv16qi_uuuu),  ENUM (UDOTV16QI)  },
+    { sv4si, uv16qi, sv16qi, NAME (usdotv16qi_ssus), ENUM (USDOTV16QI) },
+  };
+
+#undef NAME
+#undef ENUM
+
+  builtin_decls_data *bdd = bdda;
+  builtin_decls_data *bdd_end = bdd + (ARRAY_SIZE (bdda));
+
+  for (; bdd < bdd_end; bdd++)
+  {
+    ftype = build_function_type_list (bdd->out_type_node, bdd->out_type_node,
+				      bdd->in_type1_node, bdd->in_type2_node,
+				      NULL_TREE);
+    fndecl = arm_general_add_builtin_function (bdd->builtin_name,
+					       ftype, bdd->function_code);
+    arm_builtin_decls[bdd->function_code] = fndecl;
+  }
+}
+
 void
 arm_init_builtins (void)
 {
@@ -2676,6 +2739,7 @@  arm_init_builtins (void)
 	arm_init_neon_builtins ();
       arm_init_vfp_builtins ();
       arm_init_crypto_builtins ();
+      arm_init_dotprod_builtins ();
     }
 
   if (TARGET_CDE)
@@ -2738,6 +2802,37 @@  arm_builtin_decl (unsigned code, bool initialize_p ATTRIBUTE_UNUSED)
     }
 }
 
+/* Try to fold STMT, given that it's a call to the built-in function with
+   subcode FCODE.  Return the new statement on success and null on
+   failure.  */
+gimple *
+arm_general_gimple_fold_builtin (unsigned int fcode, gcall *stmt,
+				 gimple_stmt_iterator *gsi ATTRIBUTE_UNUSED)
+{
+  gimple *new_stmt = NULL;
+  unsigned nargs = gimple_call_num_args (stmt);
+  tree *args = (nargs > 0
+		? gimple_call_arg_ptr (stmt, 0)
+		: &error_mark_node);
+
+  switch (fcode)
+    {
+    case ARM_BUILTIN_NEON_SDOTV8QI:
+    case ARM_BUILTIN_NEON_SDOTV16QI:
+    case ARM_BUILTIN_NEON_UDOTV8QI:
+    case ARM_BUILTIN_NEON_UDOTV16QI:
+    case ARM_BUILTIN_NEON_USDOTV8QI:
+    case ARM_BUILTIN_NEON_USDOTV16QI:
+      new_stmt = gimple_build_assign (gimple_call_lhs (stmt),
+				      DOT_PROD_EXPR, args[1],
+				      args[2], args[0]);
+      break;
+    default:
+      break;
+    }
+  return new_stmt;
+}
+
 /* Errors in the source file can cause expand_expr to return const0_rtx
    where we expect a vector.  To avoid crashing, use one of the vector
    clear instructions.  */
diff --git a/gcc/config/arm/arm-protos.h b/gcc/config/arm/arm-protos.h
index 34d6be76e94..ae8dca3bb4e 100644
--- a/gcc/config/arm/arm-protos.h
+++ b/gcc/config/arm/arm-protos.h
@@ -57,6 +57,9 @@  extern rtx arm_expand_builtin (tree exp, rtx target, rtx subtarget
 extern tree arm_builtin_decl (unsigned code, bool initialize_p
 			      ATTRIBUTE_UNUSED);
 extern void arm_init_builtins (void);
+extern gimple *arm_general_gimple_fold_builtin (unsigned int fcode, gcall *stmt,
+						gimple_stmt_iterator *gsi
+						ATTRIBUTE_UNUSED);
 extern void arm_atomic_assign_expand_fenv (tree *hold, tree *clear, tree *update);
 extern rtx arm_simd_vect_par_cnst_half (machine_mode mode, bool high);
 extern bool arm_simd_check_vect_par_cnst_half_p (rtx op, machine_mode mode,
diff --git a/gcc/config/arm/arm.cc b/gcc/config/arm/arm.cc
index 459b7e648ab..82918849900 100644
--- a/gcc/config/arm/arm.cc
+++ b/gcc/config/arm/arm.cc
@@ -2873,6 +2873,7 @@  arm_gimple_fold_builtin (gimple_stmt_iterator *gsi)
   switch (code & ARM_BUILTIN_CLASS)
     {
     case ARM_BUILTIN_GENERAL:
+      new_stmt = arm_general_gimple_fold_builtin (subcode, stmt, gsi);
       break;
     case ARM_BUILTIN_MVE:
       new_stmt = arm_mve::gimple_fold_builtin (subcode, stmt);
diff --git a/gcc/config/arm/arm_neon_builtins.def b/gcc/config/arm/arm_neon_builtins.def
index 0c5d40b96e5..cf5537ca95d 100644
--- a/gcc/config/arm/arm_neon_builtins.def
+++ b/gcc/config/arm/arm_neon_builtins.def
@@ -349,14 +349,11 @@  VAR13 (STORE1, vst4,
 	v8qi, v4hi, v4hf, v4bf, v2si, v2sf, di, v16qi, v8hi, v8hf, v8bf, v4si, v4sf)
 VAR11 (STORE1LANE, vst4_lane,
 	v8qi, v4hi, v4hf, v2si, v2sf, v8hi, v8hf, v4si, v4sf, v4bf, v8bf)
-VAR2 (TERNOP, sdot, v8qi, v16qi)
-VAR2 (UTERNOP, udot, v8qi, v16qi)
 VAR2 (MAC_LANE, sdot_lane, v8qi, v16qi)
 VAR2 (UMAC_LANE, udot_lane, v8qi, v16qi)
 VAR2 (MAC_LANE, sdot_laneq, v8qi, v16qi)
 VAR2 (UMAC_LANE, udot_laneq, v8qi, v16qi)
 
-VAR2 (USTERNOP, usdot, v8qi, v16qi)
 VAR2 (USMAC_LANE_QUADTUP, usdot_lane, v8qi, v16qi)
 VAR2 (SUMAC_LANE_QUADTUP, sudot_lane, v8qi, v16qi)
 VAR2 (USMAC_LANE_QUADTUP, usdot_laneq, v8qi, v16qi)
diff --git a/gcc/config/arm/neon.md b/gcc/config/arm/neon.md
index fa4a7aeda35..3fbc45b8a8d 100644
--- a/gcc/config/arm/neon.md
+++ b/gcc/config/arm/neon.md
@@ -2989,7 +2989,7 @@  (define_expand "cmul<conj_op><mode>3"
 ;; ...
 ;;
 ;; and so the vectorizer provides r, in which the result has to be accumulated.
-(define_insn "<sup>dot_prod<vsi2qi>"
+(define_insn "<sup>dot_prod<mode><vsi2qi>"
   [(set (match_operand:VCVTI 0 "register_operand" "=w")
 	(plus:VCVTI
 	  (unspec:VCVTI [(match_operand:<VSI2QI> 1 "register_operand" "w")
@@ -3013,7 +3013,7 @@  (define_expand "neon_<sup>dot<vsi2qi>"
 )
 
 ;; These instructions map to the __builtins for the Dot Product operations.
-(define_insn "neon_usdot<vsi2qi>"
+(define_insn "neon_usdot<mode><vsi2qi>"
   [(set (match_operand:VCVTI 0 "register_operand" "=w")
 	(plus:VCVTI
 	  (unspec:VCVTI