
[v2,5/8] aarch64: Add masked-load else operands.

Message ID 20241018142220.173482-6-rdapp@ventanamicro.com
State New
Series Add maskload else operand.

Commit Message

Robin Dapp Oct. 18, 2024, 2:22 p.m. UTC
This adds zero else operands to masked loads and their intrinsics.
I needed to adjust more than initially expected because several
instructions rely on combine, so a change in a "base" pattern must
propagate to all of them.
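
For context: a masked load with a zero else operand makes the inactive
lanes well defined (zero) instead of unspecified, which matches SVE's
/z zeroing predication.  A minimal scalar model of the semantics, with
illustrative names that are not GCC internals:

  #include <array>
  #include <cstdio>

  // Active lanes read memory; inactive lanes take the else value
  // (always zero in this patch).
  template <typename T, std::size_t N>
  std::array<T, N>
  mask_load_else (const T *base, const std::array<bool, N> &pred,
		  const std::array<T, N> &els)
  {
    std::array<T, N> out{};
    for (std::size_t i = 0; i < N; ++i)
      out[i] = pred[i] ? base[i] : els[i];
    return out;
  }

  int
  main ()
  {
    int mem[4] = {1, 2, 3, 4};
    auto r = mask_load_else<int, 4> (mem, {true, false, true, false},
				     {0, 0, 0, 0});
    for (int v : r)
      std::printf ("%d ", v);  // prints: 1 0 3 0
  }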

For lack of a better idea, I used a function call property to specify
whether a builtin needs an else operand.  Somebody with better
knowledge of the aarch64 target can surely improve on that.
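
Concretely, a load builtin that takes an else operand now does three
things; condensed from the hunks below (the patch's own code, shortened
for illustration):

  /* 1. Advertise the property in call_properties.  */
  return CP_READ_MEMORY | CP_HAS_ELSE;

  /* 2. Push a zero else operand of the right mode before expanding.  */
  e.args.quick_push (CONST0_RTX (e.vector_mode (0)));

  /* 3. function_expander::use_contiguous_load_insn forwards it.  */
  if (call_properties () & CP_HAS_ELSE)
    add_input_operand (icode, args.last ());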

gcc/ChangeLog:

	* config/aarch64/aarch64-sve-builtins-base.cc: Add else
	handling.
	* config/aarch64/aarch64-sve-builtins.cc
	(function_expander::use_contiguous_load_insn): Ditto.
	* config/aarch64/aarch64-sve-builtins.h (CP_HAS_ELSE): New call
	property.
	* config/aarch64/aarch64-sve.md: Add else operands to the load
	and gather-load patterns.
	(@aarch64_load<SVE_PRED_LOAD:pred_load>_<ANY_EXTEND:optab><SVE_HSDI:mode><SVE_PARTIAL_I:mode>):
	Add else operand.
	(*aarch64_load<SVE_PRED_LOAD:pred_load>_<ANY_EXTEND:optab>_mov<SVE_HSDI:mode><SVE_PARTIAL_I:mode>):
	New pattern without the else operand, so combine can still match
	the extending pred_mov form.
	* config/aarch64/aarch64-sve2.md: Ditto.
	* config/aarch64/predicates.md (aarch64_maskload_else_operand):
	New predicate that only accepts a zero else operand.
---
 .../aarch64/aarch64-sve-builtins-base.cc      | 58 ++++++++++++++-----
 gcc/config/aarch64/aarch64-sve-builtins.cc    |  5 ++
 gcc/config/aarch64/aarch64-sve-builtins.h     |  1 +
 gcc/config/aarch64/aarch64-sve.md             | 47 +++++++++++++--
 gcc/config/aarch64/aarch64-sve2.md            |  3 +-
 gcc/config/aarch64/predicates.md              |  4 ++
 6 files changed, 98 insertions(+), 20 deletions(-)

Patch

diff --git a/gcc/config/aarch64/aarch64-sve-builtins-base.cc b/gcc/config/aarch64/aarch64-sve-builtins-base.cc
index 1c17149e1f0..08d2fb796dd 100644
--- a/gcc/config/aarch64/aarch64-sve-builtins-base.cc
+++ b/gcc/config/aarch64/aarch64-sve-builtins-base.cc
@@ -1476,7 +1476,7 @@  public:
   unsigned int
   call_properties (const function_instance &) const override
   {
-    return CP_READ_MEMORY;
+    return CP_READ_MEMORY | CP_HAS_ELSE;
   }
 
   gimple *
@@ -1491,11 +1491,12 @@  public:
     gimple_seq stmts = NULL;
     tree pred = f.convert_pred (stmts, vectype, 0);
     tree base = f.fold_contiguous_base (stmts, vectype);
+    tree els = build_zero_cst (vectype);
     gsi_insert_seq_before (f.gsi, stmts, GSI_SAME_STMT);
 
     tree cookie = f.load_store_cookie (TREE_TYPE (vectype));
-    gcall *new_call = gimple_build_call_internal (IFN_MASK_LOAD, 3,
-						  base, cookie, pred);
+    gcall *new_call = gimple_build_call_internal (IFN_MASK_LOAD, 4,
+						  base, cookie, pred, els);
     gimple_call_set_lhs (new_call, f.lhs);
     return new_call;
   }
@@ -1505,10 +1506,16 @@  public:
   {
     insn_code icode;
     if (e.vectors_per_tuple () == 1)
-      icode = convert_optab_handler (maskload_optab,
-				     e.vector_mode (0), e.gp_mode (0));
+      {
+	icode = convert_optab_handler (maskload_optab,
+				       e.vector_mode (0), e.gp_mode (0));
+	e.args.quick_push (CONST0_RTX (e.vector_mode (0)));
+      }
     else
-      icode = code_for_aarch64 (UNSPEC_LD1_COUNT, e.tuple_mode (0));
+      {
+	icode = code_for_aarch64 (UNSPEC_LD1_COUNT, e.tuple_mode (0));
+	e.args.quick_push (CONST0_RTX (e.tuple_mode (0)));
+      }
     return e.use_contiguous_load_insn (icode);
   }
 };
@@ -1519,12 +1526,20 @@  class svld1_extend_impl : public extending_load
 public:
   using extending_load::extending_load;
 
+  unsigned int
+  call_properties (const function_instance &) const override
+  {
+    return CP_READ_MEMORY | CP_HAS_ELSE;
+  }
+
   rtx
   expand (function_expander &e) const override
   {
     insn_code icode = code_for_aarch64_load (UNSPEC_LD1_SVE, extend_rtx_code (),
 					     e.vector_mode (0),
 					     e.memory_vector_mode ());
+    /* Add the else operand.  */
+    e.args.quick_push (CONST0_RTX (e.vector_mode (1)));
     return e.use_contiguous_load_insn (icode);
   }
 };
@@ -1535,7 +1550,7 @@  public:
   unsigned int
   call_properties (const function_instance &) const override
   {
-    return CP_READ_MEMORY;
+    return CP_READ_MEMORY | CP_HAS_ELSE;
   }
 
   rtx
@@ -1544,6 +1559,8 @@  public:
     e.prepare_gather_address_operands (1);
     /* Put the predicate last, as required by mask_gather_load_optab.  */
     e.rotate_inputs_left (0, 5);
+    /* Add the else operand.  */
+    e.args.quick_push (CONST0_RTX (e.vector_mode (0)));
     machine_mode mem_mode = e.memory_vector_mode ();
     machine_mode int_mode = aarch64_sve_int_mode (mem_mode);
     insn_code icode = convert_optab_handler (mask_gather_load_optab,
@@ -1567,6 +1584,8 @@  public:
     e.rotate_inputs_left (0, 5);
     /* Add a constant predicate for the extension rtx.  */
     e.args.quick_push (CONSTM1_RTX (VNx16BImode));
+    /* Add the else operand.  */
+    e.args.quick_push (CONST0_RTX (e.vector_mode (1)));
     insn_code icode = code_for_aarch64_gather_load (extend_rtx_code (),
 						    e.vector_mode (0),
 						    e.memory_vector_mode ());
@@ -1697,7 +1716,7 @@  public:
   unsigned int
   call_properties (const function_instance &) const override
   {
-    return CP_READ_MEMORY;
+    return CP_READ_MEMORY | CP_HAS_ELSE;
   }
 
   gimple *
@@ -1709,6 +1728,7 @@  public:
     /* Get the predicate and base pointer.  */
     gimple_seq stmts = NULL;
     tree pred = f.convert_pred (stmts, vectype, 0);
+    tree els = build_zero_cst (vectype);
     tree base = f.fold_contiguous_base (stmts, vectype);
     gsi_insert_seq_before (f.gsi, stmts, GSI_SAME_STMT);
 
@@ -1727,8 +1747,8 @@  public:
 
     /* Emit the load itself.  */
     tree cookie = f.load_store_cookie (TREE_TYPE (vectype));
-    gcall *new_call = gimple_build_call_internal (IFN_MASK_LOAD_LANES, 3,
-						  base, cookie, pred);
+    gcall *new_call = gimple_build_call_internal (IFN_MASK_LOAD_LANES, 4,
+						  base, cookie, pred, els);
     gimple_call_set_lhs (new_call, lhs_array);
     gsi_insert_after (f.gsi, new_call, GSI_SAME_STMT);
 
@@ -1741,6 +1761,7 @@  public:
     machine_mode tuple_mode = e.result_mode ();
     insn_code icode = convert_optab_handler (vec_mask_load_lanes_optab,
 					     tuple_mode, e.vector_mode (0));
+    e.args.quick_push (CONST0_RTX (e.vector_mode (0)));
     return e.use_contiguous_load_insn (icode);
   }
 };
@@ -1802,16 +1823,23 @@  public:
   unsigned int
   call_properties (const function_instance &) const override
   {
-    return CP_READ_MEMORY;
+    return CP_READ_MEMORY | CP_HAS_ELSE;
   }
 
   rtx
   expand (function_expander &e) const override
   {
-    insn_code icode = (e.vectors_per_tuple () == 1
-		       ? code_for_aarch64_ldnt1 (e.vector_mode (0))
-		       : code_for_aarch64 (UNSPEC_LDNT1_COUNT,
-					   e.tuple_mode (0)));
+    insn_code icode;
+    if (e.vectors_per_tuple () == 1)
+      {
+	icode = code_for_aarch64_ldnt1 (e.vector_mode (0));
+	e.args.quick_push (CONST0_RTX (e.vector_mode (0)));
+      }
+    else
+      {
+	icode = code_for_aarch64 (UNSPEC_LDNT1_COUNT, e.tuple_mode (0));
+	e.args.quick_push (CONST0_RTX (e.tuple_mode (0)));
+      }
     return e.use_contiguous_load_insn (icode);
   }
 };
diff --git a/gcc/config/aarch64/aarch64-sve-builtins.cc b/gcc/config/aarch64/aarch64-sve-builtins.cc
index e7c703c987e..7214f1f5a3e 100644
--- a/gcc/config/aarch64/aarch64-sve-builtins.cc
+++ b/gcc/config/aarch64/aarch64-sve-builtins.cc
@@ -4207,6 +4207,11 @@  function_expander::use_contiguous_load_insn (insn_code icode)
   add_input_operand (icode, args[0]);
   if (GET_MODE_UNIT_BITSIZE (mem_mode) < type_suffix (0).element_bits)
     add_input_operand (icode, CONSTM1_RTX (VNx16BImode));
+
+  /* If we have an else operand, add it.  */
+  if (call_properties () & CP_HAS_ELSE)
+    add_input_operand (icode, args.last ());
+
   return generate_insn (icode);
 }
 
diff --git a/gcc/config/aarch64/aarch64-sve-builtins.h b/gcc/config/aarch64/aarch64-sve-builtins.h
index 645e56badbe..6cda8bd8a8c 100644
--- a/gcc/config/aarch64/aarch64-sve-builtins.h
+++ b/gcc/config/aarch64/aarch64-sve-builtins.h
@@ -103,6 +103,7 @@  const unsigned int CP_READ_ZA = 1U << 7;
 const unsigned int CP_WRITE_ZA = 1U << 8;
 const unsigned int CP_READ_ZT0 = 1U << 9;
 const unsigned int CP_WRITE_ZT0 = 1U << 10;
+const unsigned int CP_HAS_ELSE = 1U << 11;
 
 /* Enumerates the SVE predicate and (data) vector types, together called
    "vector types" for brevity.  */
diff --git a/gcc/config/aarch64/aarch64-sve.md b/gcc/config/aarch64/aarch64-sve.md
index 06bd3e4bb2c..1e12fa3c982 100644
--- a/gcc/config/aarch64/aarch64-sve.md
+++ b/gcc/config/aarch64/aarch64-sve.md
@@ -1291,7 +1291,8 @@  (define_insn "maskload<mode><vpred>"
   [(set (match_operand:SVE_ALL 0 "register_operand" "=w")
 	(unspec:SVE_ALL
 	  [(match_operand:<VPRED> 2 "register_operand" "Upl")
-	   (match_operand:SVE_ALL 1 "memory_operand" "m")]
+	   (match_operand:SVE_ALL 1 "memory_operand" "m")
+	   (match_operand:SVE_ALL 3 "aarch64_maskload_else_operand")]
 	  UNSPEC_LD1_SVE))]
   "TARGET_SVE"
   "ld1<Vesize>\t%0.<Vctype>, %2/z, %1"
@@ -1302,11 +1303,14 @@  (define_expand "vec_load_lanes<mode><vsingle>"
   [(set (match_operand:SVE_STRUCT 0 "register_operand")
 	(unspec:SVE_STRUCT
 	  [(match_dup 2)
-	   (match_operand:SVE_STRUCT 1 "memory_operand")]
+	   (match_operand:SVE_STRUCT 1 "memory_operand")
+	   (match_dup 3)
+	  ]
 	  UNSPEC_LDN))]
   "TARGET_SVE"
   {
     operands[2] = aarch64_ptrue_reg (<VPRED>mode);
+    operands[3] = CONST0_RTX (<MODE>mode);
   }
 )
 
@@ -1315,7 +1319,8 @@  (define_insn "vec_mask_load_lanes<mode><vsingle>"
   [(set (match_operand:SVE_STRUCT 0 "register_operand" "=w")
 	(unspec:SVE_STRUCT
 	  [(match_operand:<VPRED> 2 "register_operand" "Upl")
-	   (match_operand:SVE_STRUCT 1 "memory_operand" "m")]
+	   (match_operand:SVE_STRUCT 1 "memory_operand" "m")
+	   (match_operand 3 "aarch64_maskload_else_operand")]
 	  UNSPEC_LDN))]
   "TARGET_SVE"
   "ld<vector_count><Vesize>\t%0, %2/z, %1"
@@ -1335,6 +1340,27 @@  (define_insn "vec_mask_load_lanes<mode><vsingle>"
 
 ;; Predicated load and extend, with 8 elements per 128-bit block.
 (define_insn_and_rewrite "@aarch64_load<SVE_PRED_LOAD:pred_load>_<ANY_EXTEND:optab><SVE_HSDI:mode><SVE_PARTIAL_I:mode>"
+  [(set (match_operand:SVE_HSDI 0 "register_operand" "=w")
+	(unspec:SVE_HSDI
+	  [(match_operand:<SVE_HSDI:VPRED> 3 "general_operand" "UplDnm")
+	   (ANY_EXTEND:SVE_HSDI
+	     (unspec:SVE_PARTIAL_I
+	       [(match_operand:<SVE_PARTIAL_I:VPRED> 2 "register_operand" "Upl")
+		(match_operand:SVE_PARTIAL_I 1 "memory_operand" "m")
+		(match_operand:SVE_PARTIAL_I 4 "aarch64_maskload_else_operand")]
+	       SVE_PRED_LOAD))]
+	  UNSPEC_PRED_X))]
+  "TARGET_SVE && (~<SVE_HSDI:narrower_mask> & <SVE_PARTIAL_I:self_mask>) == 0"
+  "ld1<ANY_EXTEND:s><SVE_PARTIAL_I:Vesize>\t%0.<SVE_HSDI:Vctype>, %2/z, %1"
+  "&& !CONSTANT_P (operands[3])"
+  {
+    operands[3] = CONSTM1_RTX (<SVE_HSDI:VPRED>mode);
+  }
+)
+
+;; Same as above without the maskload_else_operand to still allow combine to
+;; match a sign-extended pred_mov pattern.
+(define_insn_and_rewrite "*aarch64_load<SVE_PRED_LOAD:pred_load>_<ANY_EXTEND:optab>_mov<SVE_HSDI:mode><SVE_PARTIAL_I:mode>"
   [(set (match_operand:SVE_HSDI 0 "register_operand" "=w")
 	(unspec:SVE_HSDI
 	  [(match_operand:<SVE_HSDI:VPRED> 3 "general_operand" "UplDnm")
@@ -1433,7 +1459,8 @@  (define_insn "@aarch64_ldnt1<mode>"
   [(set (match_operand:SVE_FULL 0 "register_operand" "=w")
 	(unspec:SVE_FULL
 	  [(match_operand:<VPRED> 2 "register_operand" "Upl")
-	   (match_operand:SVE_FULL 1 "memory_operand" "m")]
+	   (match_operand:SVE_FULL 1 "memory_operand" "m")
+	   (match_operand:SVE_FULL 3 "aarch64_maskload_else_operand")]
 	  UNSPEC_LDNT1_SVE))]
   "TARGET_SVE"
   "ldnt1<Vesize>\t%0.<Vetype>, %2/z, %1"
@@ -1456,11 +1483,13 @@  (define_expand "gather_load<mode><v_int_container>"
 	   (match_operand:<V_INT_CONTAINER> 2 "register_operand")
 	   (match_operand:DI 3 "const_int_operand")
 	   (match_operand:DI 4 "aarch64_gather_scale_operand_<Vesize>")
+	   (match_dup 6)
 	   (mem:BLK (scratch))]
 	  UNSPEC_LD1_GATHER))]
   "TARGET_SVE && TARGET_NON_STREAMING"
   {
     operands[5] = aarch64_ptrue_reg (<VPRED>mode);
+    operands[6] = CONST0_RTX (<MODE>mode);
   }
 )
 
@@ -1474,6 +1503,7 @@  (define_insn "mask_gather_load<mode><v_int_container>"
 	   (match_operand:VNx4SI 2 "register_operand")
 	   (match_operand:DI 3 "const_int_operand")
 	   (match_operand:DI 4 "aarch64_gather_scale_operand_<Vesize>")
+	   (match_operand:SVE_4 6 "aarch64_maskload_else_operand")
 	   (mem:BLK (scratch))]
 	  UNSPEC_LD1_GATHER))]
   "TARGET_SVE && TARGET_NON_STREAMING"
@@ -1503,6 +1533,7 @@  (define_insn "mask_gather_load<mode><v_int_container>"
 	   (match_operand:VNx2DI 2 "register_operand")
 	   (match_operand:DI 3 "const_int_operand")
 	   (match_operand:DI 4 "aarch64_gather_scale_operand_<Vesize>")
+	   (match_operand:SVE_2 6 "aarch64_maskload_else_operand")
 	   (mem:BLK (scratch))]
 	  UNSPEC_LD1_GATHER))]
   "TARGET_SVE && TARGET_NON_STREAMING"
@@ -1531,6 +1562,7 @@  (define_insn_and_rewrite "*mask_gather_load<mode><v_int_container>_<su>xtw_unpac
 	     UNSPEC_PRED_X)
 	   (match_operand:DI 3 "const_int_operand")
 	   (match_operand:DI 4 "aarch64_gather_scale_operand_<Vesize>")
+	   (match_operand:SVE_2 7 "aarch64_maskload_else_operand")
 	   (mem:BLK (scratch))]
 	  UNSPEC_LD1_GATHER))]
   "TARGET_SVE && TARGET_NON_STREAMING"
@@ -1561,6 +1593,7 @@  (define_insn_and_rewrite "*mask_gather_load<mode><v_int_container>_sxtw"
 	     UNSPEC_PRED_X)
 	   (match_operand:DI 3 "const_int_operand")
 	   (match_operand:DI 4 "aarch64_gather_scale_operand_<Vesize>")
+	   (match_operand:SVE_2 7 "aarch64_maskload_else_operand")
 	   (mem:BLK (scratch))]
 	  UNSPEC_LD1_GATHER))]
   "TARGET_SVE && TARGET_NON_STREAMING"
@@ -1588,6 +1621,7 @@  (define_insn "*mask_gather_load<mode><v_int_container>_uxtw"
 	     (match_operand:VNx2DI 6 "aarch64_sve_uxtw_immediate"))
 	   (match_operand:DI 3 "const_int_operand")
 	   (match_operand:DI 4 "aarch64_gather_scale_operand_<Vesize>")
+	   (match_operand:SVE_2 7 "aarch64_maskload_else_operand")
 	   (mem:BLK (scratch))]
 	  UNSPEC_LD1_GATHER))]
   "TARGET_SVE && TARGET_NON_STREAMING"
@@ -1624,6 +1658,7 @@  (define_insn_and_rewrite "@aarch64_gather_load_<ANY_EXTEND:optab><SVE_4HSI:mode>
 		(match_operand:VNx4SI 2 "register_operand")
 		(match_operand:DI 3 "const_int_operand")
 		(match_operand:DI 4 "aarch64_gather_scale_operand_<SVE_4BHI:Vesize>")
+		(match_operand:SVE_4BHI 7 "aarch64_maskload_else_operand")
 		(mem:BLK (scratch))]
 	       UNSPEC_LD1_GATHER))]
 	  UNSPEC_PRED_X))]
@@ -1663,6 +1698,7 @@  (define_insn_and_rewrite "@aarch64_gather_load_<ANY_EXTEND:optab><SVE_2HSDI:mode
 		(match_operand:VNx2DI 2 "register_operand")
 		(match_operand:DI 3 "const_int_operand")
 		(match_operand:DI 4 "aarch64_gather_scale_operand_<SVE_2BHSI:Vesize>")
+		(match_operand:SVE_2BHSI 7 "aarch64_maskload_else_operand")
 		(mem:BLK (scratch))]
 	       UNSPEC_LD1_GATHER))]
 	  UNSPEC_PRED_X))]
@@ -1701,6 +1737,7 @@  (define_insn_and_rewrite "*aarch64_gather_load_<ANY_EXTEND:optab><SVE_2HSDI:mode
 		  UNSPEC_PRED_X)
 		(match_operand:DI 3 "const_int_operand")
 		(match_operand:DI 4 "aarch64_gather_scale_operand_<SVE_2BHSI:Vesize>")
+		(match_operand:SVE_2BHSI 8 "aarch64_maskload_else_operand")
 		(mem:BLK (scratch))]
 	       UNSPEC_LD1_GATHER))]
 	  UNSPEC_PRED_X))]
@@ -1738,6 +1775,7 @@  (define_insn_and_rewrite "*aarch64_gather_load_<ANY_EXTEND:optab><SVE_2HSDI:mode
 		  UNSPEC_PRED_X)
 		(match_operand:DI 3 "const_int_operand")
 		(match_operand:DI 4 "aarch64_gather_scale_operand_<SVE_2BHSI:Vesize>")
+		(match_operand:SVE_2BHSI 8 "aarch64_maskload_else_operand")
 		(mem:BLK (scratch))]
 	       UNSPEC_LD1_GATHER))]
 	  UNSPEC_PRED_X))]
@@ -1772,6 +1810,7 @@  (define_insn_and_rewrite "*aarch64_gather_load_<ANY_EXTEND:optab><SVE_2HSDI:mode
 		  (match_operand:VNx2DI 6 "aarch64_sve_uxtw_immediate"))
 		(match_operand:DI 3 "const_int_operand")
 		(match_operand:DI 4 "aarch64_gather_scale_operand_<SVE_2BHSI:Vesize>")
+		(match_operand:SVE_2BHSI 8 "aarch64_maskload_else_operand")
 		(mem:BLK (scratch))]
 	       UNSPEC_LD1_GATHER))]
 	  UNSPEC_PRED_X))]
diff --git a/gcc/config/aarch64/aarch64-sve2.md b/gcc/config/aarch64/aarch64-sve2.md
index 5f2697c3179..22e8632af80 100644
--- a/gcc/config/aarch64/aarch64-sve2.md
+++ b/gcc/config/aarch64/aarch64-sve2.md
@@ -138,7 +138,8 @@  (define_insn "@aarch64_<optab><mode>"
   [(set (match_operand:SVE_FULLx24 0 "aligned_register_operand" "=Uw<vector_count>")
 	(unspec:SVE_FULLx24
 	  [(match_operand:VNx16BI 2 "register_operand" "Uph")
-	   (match_operand:SVE_FULLx24 1 "memory_operand" "m")]
+	   (match_operand:SVE_FULLx24 1 "memory_operand" "m")
+	   (match_operand:SVE_FULLx24 3 "aarch64_maskload_else_operand")]
 	  LD1_COUNT))]
   "TARGET_STREAMING_SME2"
   "<optab><Vesize>\t%0, %K2/z, %1"
diff --git a/gcc/config/aarch64/predicates.md b/gcc/config/aarch64/predicates.md
index 8f3aab2272c..744f36ff67d 100644
--- a/gcc/config/aarch64/predicates.md
+++ b/gcc/config/aarch64/predicates.md
@@ -1069,3 +1069,7 @@  (define_predicate "aarch64_granule16_simm9"
   (and (match_code "const_int")
        (match_test "IN_RANGE (INTVAL (op),  -4096, 4080)
 		    && !(INTVAL (op) & 0xf)")))
+
+(define_predicate "aarch64_maskload_else_operand"
+  (and (match_code "const_int,const_vector")
+       (match_test "op == CONST0_RTX (GET_MODE (op))")))