diff mbox series

[v4,2/8] ifn: Add else-operand handling.

Message ID 20241107175753.322338-3-rdapp.gcc@gmail.com
State New
Headers show
Series Add maskload else operand. | expand

Commit Message

Robin Dapp Nov. 7, 2024, 5:57 p.m. UTC
From: Robin Dapp <rdapp@ventanamicro.com>

This patch adds else-operand handling to the internal functions.

gcc/ChangeLog:

	* internal-fn.cc (add_mask_and_len_args): Rename...
	(add_mask_else_and_len_args): ...to this and add else handling.
	(expand_partial_load_optab_fn): Use adjusted function.
	(expand_partial_store_optab_fn): Ditto.
	(expand_scatter_store_optab_fn): Ditto.
	(expand_gather_load_optab_fn): Ditto.
	(internal_fn_len_index): Add else handling.
	(internal_fn_else_index): Ditto.
	(internal_fn_mask_index): Ditto.
	(get_supported_else_vals): New function.
	(supported_else_val_p): New function.
	(internal_gather_scatter_fn_supported_p): Add else operand.
	* internal-fn.h (internal_gather_scatter_fn_supported_p): Define
	else constants.
	(MASK_LOAD_ELSE_ZERO): Ditto.
	(MASK_LOAD_ELSE_M1): Ditto.
	(MASK_LOAD_ELSE_UNDEFINED): Ditto.
	(get_supported_else_vals): Declare.
	(supported_else_val_p): Ditto.
---
 gcc/internal-fn.cc | 148 ++++++++++++++++++++++++++++++++++++++-------
 gcc/internal-fn.h  |  13 +++-
 2 files changed, 139 insertions(+), 22 deletions(-)
diff mbox series

Patch

diff --git a/gcc/internal-fn.cc b/gcc/internal-fn.cc
index 1b3fe7be047..6b4f344b40e 100644
--- a/gcc/internal-fn.cc
+++ b/gcc/internal-fn.cc
@@ -333,17 +333,18 @@  get_multi_vector_move (tree array_type, convert_optab optab)
   return convert_optab_handler (optab, imode, vmode);
 }
 
-/* Add mask and len arguments according to the STMT.  */
+/* Add mask, else, and len arguments according to the STMT.  */
 
 static unsigned int
-add_mask_and_len_args (expand_operand *ops, unsigned int opno, gcall *stmt)
+add_mask_else_and_len_args (expand_operand *ops, unsigned int opno, gcall *stmt)
 {
   internal_fn ifn = gimple_call_internal_fn (stmt);
   int len_index = internal_fn_len_index (ifn);
   /* BIAS is always consecutive next of LEN.  */
   int bias_index = len_index + 1;
   int mask_index = internal_fn_mask_index (ifn);
-  /* The order of arguments are always {len,bias,mask}.  */
+
+  /* The order of arguments is always {mask, else, len, bias}.  */
   if (mask_index >= 0)
     {
       tree mask = gimple_call_arg (stmt, mask_index);
@@ -365,6 +366,22 @@  add_mask_and_len_args (expand_operand *ops, unsigned int opno, gcall *stmt)
       create_input_operand (&ops[opno++], mask_rtx,
 			    TYPE_MODE (TREE_TYPE (mask)));
     }
+
+  int els_index = internal_fn_else_index (ifn);
+  if (els_index >= 0)
+    {
+      tree els = gimple_call_arg (stmt, els_index);
+      tree els_type = TREE_TYPE (els);
+      if (TREE_CODE (els) == SSA_NAME
+	  && SSA_NAME_IS_DEFAULT_DEF (els)
+	  && VAR_P (SSA_NAME_VAR (els)))
+	create_undefined_input_operand (&ops[opno++], TYPE_MODE (els_type));
+      else
+	{
+	  rtx els_rtx = expand_normal (els);
+	  create_input_operand (&ops[opno++], els_rtx, TYPE_MODE (els_type));
+	}
+    }
   if (len_index >= 0)
     {
       tree len = gimple_call_arg (stmt, len_index);
@@ -3016,7 +3033,7 @@  static void
 expand_partial_load_optab_fn (internal_fn ifn, gcall *stmt, convert_optab optab)
 {
   int i = 0;
-  class expand_operand ops[5];
+  class expand_operand ops[6];
   tree type, lhs, rhs, maskt;
   rtx mem, target;
   insn_code icode;
@@ -3046,7 +3063,7 @@  expand_partial_load_optab_fn (internal_fn ifn, gcall *stmt, convert_optab optab)
   target = expand_expr (lhs, NULL_RTX, VOIDmode, EXPAND_WRITE);
   create_call_lhs_operand (&ops[i++], target, TYPE_MODE (type));
   create_fixed_operand (&ops[i++], mem);
-  i = add_mask_and_len_args (ops, i, stmt);
+  i = add_mask_else_and_len_args (ops, i, stmt);
   expand_insn (icode, i, ops);
 
   assign_call_lhs (lhs, target, &ops[0]);
@@ -3092,7 +3109,7 @@  expand_partial_store_optab_fn (internal_fn ifn, gcall *stmt, convert_optab optab
   reg = expand_normal (rhs);
   create_fixed_operand (&ops[i++], mem);
   create_input_operand (&ops[i++], reg, TYPE_MODE (type));
-  i = add_mask_and_len_args (ops, i, stmt);
+  i = add_mask_else_and_len_args (ops, i, stmt);
   expand_insn (icode, i, ops);
 }
 
@@ -3678,7 +3695,7 @@  expand_scatter_store_optab_fn (internal_fn, gcall *stmt, direct_optab optab)
   create_integer_operand (&ops[i++], TYPE_UNSIGNED (TREE_TYPE (offset)));
   create_integer_operand (&ops[i++], scale_int);
   create_input_operand (&ops[i++], rhs_rtx, TYPE_MODE (TREE_TYPE (rhs)));
-  i = add_mask_and_len_args (ops, i, stmt);
+  i = add_mask_else_and_len_args (ops, i, stmt);
 
   insn_code icode = convert_optab_handler (optab, TYPE_MODE (TREE_TYPE (rhs)),
 					   TYPE_MODE (TREE_TYPE (offset)));
@@ -3701,13 +3718,13 @@  expand_gather_load_optab_fn (internal_fn, gcall *stmt, direct_optab optab)
   HOST_WIDE_INT scale_int = tree_to_shwi (scale);
 
   int i = 0;
-  class expand_operand ops[8];
+  class expand_operand ops[9];
   create_call_lhs_operand (&ops[i++], lhs_rtx, TYPE_MODE (TREE_TYPE (lhs)));
   create_address_operand (&ops[i++], base_rtx);
   create_input_operand (&ops[i++], offset_rtx, TYPE_MODE (TREE_TYPE (offset)));
   create_integer_operand (&ops[i++], TYPE_UNSIGNED (TREE_TYPE (offset)));
   create_integer_operand (&ops[i++], scale_int);
-  i = add_mask_and_len_args (ops, i, stmt);
+  i = add_mask_else_and_len_args (ops, i, stmt);
   insn_code icode = convert_optab_handler (optab, TYPE_MODE (TREE_TYPE (lhs)),
 					   TYPE_MODE (TREE_TYPE (offset)));
   expand_insn (icode, i, ops);
@@ -3729,14 +3746,14 @@  expand_strided_load_optab_fn (ATTRIBUTE_UNUSED internal_fn, gcall *stmt,
   rtx stride_rtx = expand_normal (stride);
 
   unsigned i = 0;
-  class expand_operand ops[6];
+  class expand_operand ops[7];
   machine_mode mode = TYPE_MODE (TREE_TYPE (lhs));
 
   create_output_operand (&ops[i++], lhs_rtx, mode);
   create_address_operand (&ops[i++], base_rtx);
   create_address_operand (&ops[i++], stride_rtx);
 
-  i = add_mask_and_len_args (ops, i, stmt);
+  i = add_mask_else_and_len_args (ops, i, stmt);
   expand_insn (direct_optab_handler (optab, mode), i, ops);
 
   if (!rtx_equal_p (lhs_rtx, ops[0].value))
@@ -3768,7 +3785,7 @@  expand_strided_store_optab_fn (ATTRIBUTE_UNUSED internal_fn, gcall *stmt,
   create_address_operand (&ops[i++], stride_rtx);
   create_input_operand (&ops[i++], rhs_rtx, mode);
 
-  i = add_mask_and_len_args (ops, i, stmt);
+  i = add_mask_else_and_len_args (ops, i, stmt);
   expand_insn (direct_optab_handler (optab, mode), i, ops);
 }
 
@@ -4662,6 +4679,18 @@  get_len_internal_fn (internal_fn fn)
   case IFN_COND_##NAME:                                                        \
     return IFN_COND_LEN_##NAME;
 #include "internal-fn.def"
+    default:
+      break;
+    }
+
+  switch (fn)
+    {
+    case IFN_MASK_LOAD:
+      return IFN_MASK_LEN_LOAD;
+    case IFN_MASK_LOAD_LANES:
+      return IFN_MASK_LEN_LOAD_LANES;
+    case IFN_MASK_GATHER_LOAD:
+      return IFN_MASK_LEN_GATHER_LOAD;
     default:
       return IFN_LAST;
     }
@@ -4847,8 +4876,13 @@  internal_fn_len_index (internal_fn fn)
     case IFN_LEN_STORE:
       return 2;
 
-    case IFN_MASK_LEN_GATHER_LOAD:
     case IFN_MASK_LEN_SCATTER_STORE:
+    case IFN_MASK_LEN_STRIDED_LOAD:
+      return 5;
+
+    case IFN_MASK_LEN_GATHER_LOAD:
+      return 6;
+
     case IFN_COND_LEN_FMA:
     case IFN_COND_LEN_FMS:
     case IFN_COND_LEN_FNMA:
@@ -4870,18 +4904,19 @@  internal_fn_len_index (internal_fn fn)
     case IFN_COND_LEN_XOR:
     case IFN_COND_LEN_SHL:
     case IFN_COND_LEN_SHR:
-    case IFN_MASK_LEN_STRIDED_LOAD:
     case IFN_MASK_LEN_STRIDED_STORE:
       return 4;
 
     case IFN_COND_LEN_NEG:
-    case IFN_MASK_LEN_LOAD:
     case IFN_MASK_LEN_STORE:
-    case IFN_MASK_LEN_LOAD_LANES:
     case IFN_MASK_LEN_STORE_LANES:
     case IFN_VCOND_MASK_LEN:
       return 3;
 
+    case IFN_MASK_LEN_LOAD:
+    case IFN_MASK_LEN_LOAD_LANES:
+      return 4;
+
     default:
       return -1;
     }
@@ -4931,6 +4966,12 @@  internal_fn_else_index (internal_fn fn)
     case IFN_COND_LEN_SHR:
       return 3;
 
+    case IFN_MASK_LOAD:
+    case IFN_MASK_LEN_LOAD:
+    case IFN_MASK_LOAD_LANES:
+    case IFN_MASK_LEN_LOAD_LANES:
+      return 3;
+
     case IFN_COND_FMA:
     case IFN_COND_FMS:
     case IFN_COND_FNMA:
@@ -4939,8 +4980,13 @@  internal_fn_else_index (internal_fn fn)
     case IFN_COND_LEN_FMS:
     case IFN_COND_LEN_FNMA:
     case IFN_COND_LEN_FNMS:
+    case IFN_MASK_LEN_STRIDED_LOAD:
       return 4;
 
+    case IFN_MASK_GATHER_LOAD:
+    case IFN_MASK_LEN_GATHER_LOAD:
+      return 5;
+
     default:
       return -1;
     }
@@ -4976,6 +5022,7 @@  internal_fn_mask_index (internal_fn fn)
     case IFN_MASK_LEN_SCATTER_STORE:
       return 4;
 
+    case IFN_VCOND_MASK:
     case IFN_VCOND_MASK_LEN:
       return 0;
 
@@ -5015,6 +5062,52 @@  internal_fn_stored_value_index (internal_fn fn)
     }
 }
 
+
+/* Store all supported else values for the optab referred to by ICODE
+   in ELSE_VALS.  The index of the else operand must be specified in
+   ELSE_INDEX.  */
+
+void
+get_supported_else_vals (enum insn_code icode, unsigned else_index,
+			 vec<int> &else_vals)
+{
+  const struct insn_data_d *data = &insn_data[icode];
+  if ((char)else_index >= data->n_operands)
+    return;
+
+  machine_mode else_mode = data->operand[else_index].mode;
+
+  else_vals.truncate (0);
+
+  /* For now we only support else values of 0, -1, and "undefined".  */
+  if (insn_operand_matches (icode, else_index, CONST0_RTX (else_mode)))
+    else_vals.safe_push (MASK_LOAD_ELSE_ZERO);
+
+  if (insn_operand_matches (icode, else_index, gen_rtx_SCRATCH (else_mode)))
+    else_vals.safe_push (MASK_LOAD_ELSE_UNDEFINED);
+
+  if (GET_MODE_CLASS (else_mode) == MODE_VECTOR_INT
+      && insn_operand_matches (icode, else_index, CONSTM1_RTX (else_mode)))
+    else_vals.safe_push (MASK_LOAD_ELSE_M1);
+}
+
+/* Return true if the else value ELSE_VAL (one of MASK_LOAD_ELSE_ZERO,
+   MASK_LOAD_ELSE_M1, and MASK_LOAD_ELSE_UNDEFINED) is valid fo the optab
+   referred to by ICODE.  The index of the else operand must be specified
+   in ELSE_INDEX.  */
+
+bool
+supported_else_val_p (enum insn_code icode, unsigned else_index, int else_val)
+{
+  if (else_val != MASK_LOAD_ELSE_ZERO && else_val != MASK_LOAD_ELSE_M1
+      && else_val != MASK_LOAD_ELSE_UNDEFINED)
+    gcc_unreachable ();
+
+  auto_vec<int> else_vals;
+  get_supported_else_vals (icode, else_index, else_vals);
+  return else_vals.contains (else_val);
+}
+
 /* Return true if the target supports gather load or scatter store function
    IFN.  For loads, VECTOR_TYPE is the vector type of the load result,
    while for stores it is the vector type of the stored data argument.
@@ -5022,12 +5115,15 @@  internal_fn_stored_value_index (internal_fn fn)
    or stored.  OFFSET_VECTOR_TYPE is the vector type that holds the
    offset from the shared base address of each loaded or stored element.
    SCALE is the amount by which these offsets should be multiplied
-   *after* they have been extended to address width.  */
+   *after* they have been extended to address width.
+   If the target supports the gather load the supported else values
+   will be added to the vector ELSVAL points to if it is nonzero.  */
 
 bool
 internal_gather_scatter_fn_supported_p (internal_fn ifn, tree vector_type,
 					tree memory_element_type,
-					tree offset_vector_type, int scale)
+					tree offset_vector_type, int scale,
+					vec<int> *elsvals)
 {
   if (!tree_int_cst_equal (TYPE_SIZE (TREE_TYPE (vector_type)),
 			   TYPE_SIZE (memory_element_type)))
@@ -5040,9 +5136,19 @@  internal_gather_scatter_fn_supported_p (internal_fn ifn, tree vector_type,
 					   TYPE_MODE (offset_vector_type));
   int output_ops = internal_load_fn_p (ifn) ? 1 : 0;
   bool unsigned_p = TYPE_UNSIGNED (TREE_TYPE (offset_vector_type));
-  return (icode != CODE_FOR_nothing
-	  && insn_operand_matches (icode, 2 + output_ops, GEN_INT (unsigned_p))
-	  && insn_operand_matches (icode, 3 + output_ops, GEN_INT (scale)));
+  bool ok = icode != CODE_FOR_nothing
+    && insn_operand_matches (icode, 2 + output_ops, GEN_INT (unsigned_p))
+    && insn_operand_matches (icode, 3 + output_ops, GEN_INT (scale));
+
+  /* For gather the optab's operand indices do not match the IFN's because
+     the latter does not have the extension operand (operand 3).  It is
+     implicitly added during expansion so we use the IFN's else index + 1.
+     */
+  if (ok && elsvals)
+    get_supported_else_vals
+      (icode, internal_fn_else_index (IFN_MASK_GATHER_LOAD) + 1, *elsvals);
+
+  return ok;
 }
 
 /* Return true if the target supports IFN_CHECK_{RAW,WAR}_PTRS function IFN
diff --git a/gcc/internal-fn.h b/gcc/internal-fn.h
index 2785a5a95a2..37fbc60f6dd 100644
--- a/gcc/internal-fn.h
+++ b/gcc/internal-fn.h
@@ -240,9 +240,20 @@  extern int internal_fn_len_index (internal_fn);
 extern int internal_fn_else_index (internal_fn);
 extern int internal_fn_stored_value_index (internal_fn);
 extern bool internal_gather_scatter_fn_supported_p (internal_fn, tree,
-						    tree, tree, int);
+						    tree, tree, int,
+						    vec<int> * = nullptr);
 extern bool internal_check_ptrs_fn_supported_p (internal_fn, tree,
 						poly_uint64, unsigned int);
+
+/* Integer constants representing which else value is supported for masked load
+   functions.  */
+#define MASK_LOAD_ELSE_ZERO -1
+#define MASK_LOAD_ELSE_M1 -2
+#define MASK_LOAD_ELSE_UNDEFINED -3
+
+extern void get_supported_else_vals (enum insn_code, unsigned, vec<int> &);
+extern bool supported_else_val_p (enum insn_code, unsigned, int);
+
 #define VECT_PARTIAL_BIAS_UNSUPPORTED 127
 
 extern signed char internal_len_load_store_bias (internal_fn ifn,