@@ -458,6 +458,13 @@ aarch64_types_storestruct_lane_p_qualifiers[SIMD_MAX_BUILTIN_ARGS]
qualifier_poly, qualifier_struct_load_store_lane_index };
#define TYPES_STORESTRUCT_LANE_P (aarch64_types_storestruct_lane_p_qualifiers)
+constexpr insn_code CODE_FOR_aarch64_sdot_prodv8qi = CODE_FOR_sdot_prodv2siv8qi;
+constexpr insn_code CODE_FOR_aarch64_udot_prodv8qi = CODE_FOR_udot_prodv2siv8qi;
+constexpr insn_code CODE_FOR_aarch64_usdot_prodv8qi = CODE_FOR_usdot_prodv2siv8qi;
+constexpr insn_code CODE_FOR_aarch64_sdot_prodv16qi = CODE_FOR_sdot_prodv4siv16qi;
+constexpr insn_code CODE_FOR_aarch64_udot_prodv16qi = CODE_FOR_udot_prodv4siv16qi;
+constexpr insn_code CODE_FOR_aarch64_usdot_prodv16qi = CODE_FOR_usdot_prodv4siv16qi;
+
#define CF0(N, X) CODE_FOR_aarch64_##N##X
#define CF1(N, X) CODE_FOR_##N##X##1
#define CF2(N, X) CODE_FOR_##N##X##2
@@ -418,9 +418,9 @@
BUILTIN_VSDQ_I_DI (BINOP_UUS, urshl, 0, NONE)
/* Implemented by <sur><dotprod>_prod<dot_mode>. */
- BUILTIN_VB (TERNOP, sdot_prod, 10, NONE)
- BUILTIN_VB (TERNOPU, udot_prod, 10, NONE)
- BUILTIN_VB (TERNOP_SUSS, usdot_prod, 10, NONE)
+ BUILTIN_VB (TERNOP, sdot_prod, 0, NONE)
+ BUILTIN_VB (TERNOPU, udot_prod, 0, NONE)
+ BUILTIN_VB (TERNOP_SUSS, usdot_prod, 0, NONE)
/* Implemented by aarch64_<sur><dotprod>_lane{q}<dot_mode>. */
BUILTIN_VB (QUADOP_LANE, sdot_lane, 0, NONE)
BUILTIN_VB (QUADOPU_LANE, udot_lane, 0, NONE)
@@ -568,7 +568,7 @@ (define_expand "cmul<conj_op><mode>3"
;; ...
;;
;; and so the vectorizer provides r, in which the result has to be accumulated.
-(define_insn "<sur>dot_prod<vsi2qi><vczle><vczbe>"
+(define_insn "<sur>dot_prod<mode><vsi2qi><vczle><vczbe>"
[(set (match_operand:VS 0 "register_operand" "=w")
(plus:VS
(unspec:VS [(match_operand:<VSI2QI> 1 "register_operand" "w")
@@ -582,7 +582,7 @@ (define_insn "<sur>dot_prod<vsi2qi><vczle><vczbe>"
;; These instructions map to the __builtins for the Armv8.6-a I8MM usdot
;; (vector) Dot Product operation and the vectorized optab.
-(define_insn "usdot_prod<vsi2qi><vczle><vczbe>"
+(define_insn "usdot_prod<mode><vsi2qi><vczle><vczbe>"
[(set (match_operand:VS 0 "register_operand" "=w")
(plus:VS
(unspec:VS [(match_operand:<VSI2QI> 1 "register_operand" "w")
@@ -1075,7 +1075,7 @@ (define_expand "<su>sadv16qi"
rtx ones = force_reg (V16QImode, CONST1_RTX (V16QImode));
rtx abd = gen_reg_rtx (V16QImode);
emit_insn (gen_aarch64_<su>abdv16qi (abd, operands[1], operands[2]));
- emit_insn (gen_udot_prodv16qi (operands[0], abd, ones, operands[3]));
+ emit_insn (gen_udot_prodv4siv16qi (operands[0], abd, ones, operands[3]));
DONE;
}
rtx reduc = gen_reg_rtx (V8HImode);
@@ -3528,6 +3528,7 @@ (define_expand "popcount<mode>2"
/* Generate a byte popcount. */
machine_mode mode = <bitsize> == 64 ? V8QImode : V16QImode;
+ machine_mode mode2 = <bitsize> == 64 ? V2SImode : V4SImode;
rtx tmp = gen_reg_rtx (mode);
auto icode = optab_handler (popcount_optab, mode);
emit_insn (GEN_FCN (icode) (tmp, gen_lowpart (mode, operands[1])));
@@ -3538,7 +3539,7 @@ (define_expand "popcount<mode>2"
/* For V4SI and V2SI, we can generate a UDOT with a 0 accumulator and a
1 multiplicand. For V2DI, another UAADDLP is needed. */
rtx ones = force_reg (mode, CONST1_RTX (mode));
- auto icode = optab_handler (udot_prod_optab, mode);
+ auto icode = convert_optab_handler (udot_prod_optab, mode2, mode);
mode = <bitsize> == 64 ? V2SImode : V4SImode;
rtx dest = mode == <MODE>mode ? operands[0] : gen_reg_rtx (mode);
rtx zeros = force_reg (mode, CONST0_RTX (mode));
@@ -804,15 +804,16 @@ public:
e.rotate_inputs_left (0, 3);
insn_code icode;
if (e.type_suffix_ids[1] == NUM_TYPE_SUFFIXES)
- icode = e.direct_optab_handler_for_sign (sdot_prod_optab,
- udot_prod_optab,
- 0, GET_MODE (e.args[0]));
+ icode = e.convert_optab_handler_for_sign (sdot_prod_optab,
+ udot_prod_optab,
+ 0, e.result_mode (),
+ GET_MODE (e.args[0]));
else
icode = (e.type_suffix (0).float_p
? CODE_FOR_aarch64_sve_fdotvnx4sfvnx8hf
: e.type_suffix (0).unsigned_p
- ? CODE_FOR_aarch64_sve_udotvnx4sivnx8hi
- : CODE_FOR_aarch64_sve_sdotvnx4sivnx8hi);
+ ? CODE_FOR_udot_prodvnx4sivnx8hi
+ : CODE_FOR_sdot_prodvnx4sivnx8hi);
return e.use_unpred_insn (icode);
}
};
@@ -2861,7 +2862,7 @@ public:
Hence we do the same rotation on arguments as svdot_impl does. */
e.rotate_inputs_left (0, 3);
machine_mode mode = e.vector_mode (0);
- insn_code icode = code_for_dot_prod (UNSPEC_USDOT, mode);
+ insn_code icode = code_for_dot_prod (UNSPEC_USDOT, e.result_mode (), mode);
return e.use_exact_insn (icode);
}
@@ -3745,6 +3745,23 @@ function_expander::direct_optab_handler_for_sign (optab signed_op,
return ::direct_optab_handler (op, mode);
}
+/* Choose between signed and unsigned convert optabs SIGNED_OP and
+   UNSIGNED_OP based on the signedness of type suffix SUFFIX_I, then
+   pick the appropriate optab handler for converting to TO_MODE from
+   FROM_MODE, defaulting FROM_MODE to the mode of type suffix SUFFIX_I.  */
+insn_code
+function_expander::convert_optab_handler_for_sign (optab signed_op,
+ optab unsigned_op,
+ unsigned int suffix_i,
+ machine_mode to_mode,
+ machine_mode from_mode)
+{
+ if (from_mode == VOIDmode)
+ from_mode = vector_mode (suffix_i);
+ optab op = type_suffix (suffix_i).unsigned_p ? unsigned_op : signed_op;
+ return ::convert_optab_handler (op, to_mode, from_mode);
+}
+
/* Return true if X overlaps any input. */
bool
function_expander::overlaps_input_p (rtx x)
@@ -659,6 +659,9 @@ public:
insn_code direct_optab_handler (optab, unsigned int = 0);
insn_code direct_optab_handler_for_sign (optab, optab, unsigned int = 0,
machine_mode = E_VOIDmode);
+ insn_code convert_optab_handler_for_sign (optab, optab, unsigned int = 0,
+ machine_mode = E_VOIDmode,
+ machine_mode = E_VOIDmode);
machine_mode result_mode () const;
@@ -7197,7 +7197,7 @@ (define_insn_and_rewrite "*cond_fnma<mode>_any"
;; -------------------------------------------------------------------------
;; Four-element integer dot-product with accumulation.
-(define_insn "<sur>dot_prod<vsi2qi>"
+(define_insn "<sur>dot_prod<mode><vsi2qi>"
[(set (match_operand:SVE_FULL_SDI 0 "register_operand")
(plus:SVE_FULL_SDI
(unspec:SVE_FULL_SDI
@@ -7235,7 +7235,7 @@ (define_insn "@aarch64_<sur>dot_prod_lane<SVE_FULL_SDI:mode><SVE_FULL_BHI:mode>"
}
)
-(define_insn "@<sur>dot_prod<vsi2qi>"
+(define_insn "@<sur>dot_prod<mode><vsi2qi>"
[(set (match_operand:VNx4SI_ONLY 0 "register_operand")
(plus:VNx4SI_ONLY
(unspec:VNx4SI_ONLY
@@ -7293,7 +7293,7 @@ (define_expand "<su>sad<vsi2qi>"
rtx ones = force_reg (<VSI2QI>mode, CONST1_RTX (<VSI2QI>mode));
rtx diff = gen_reg_rtx (<VSI2QI>mode);
emit_insn (gen_<su>abd<vsi2qi>3 (diff, operands[1], operands[2]));
- emit_insn (gen_udot_prod<vsi2qi> (operands[0], diff, ones, operands[3]));
+ emit_insn (gen_udot_prod<mode><vsi2qi> (operands[0], diff, ones, operands[3]));
DONE;
}
)
@@ -2021,7 +2021,7 @@ (define_insn "@aarch64_sve_qsub_<sve_int_op>_lane_<mode>"
)
;; Two-way dot-product.
-(define_insn "@aarch64_sve_<sur>dotvnx4sivnx8hi"
+(define_insn "<sur>dot_prodvnx4sivnx8hi"
[(set (match_operand:VNx4SI 0 "register_operand")
(plus:VNx4SI
(unspec:VNx4SI
new file mode 100644
@@ -0,0 +1,25 @@
+/* { dg-additional-options "-march=armv9.2-a+sme2 -O2 -ftree-vectorize" } */
+
+#include <stdint.h>
+
+uint32_t udot2(int n, uint16_t* data) __arm_streaming
+{
+ uint32_t sum = 0;
+ for (int i=0; i<n; i+=1) {
+ sum += data[i] * data[i];
+ }
+ return sum;
+}
+
+int32_t sdot2(int n, int16_t* data) __arm_streaming
+{
+ int32_t sum = 0;
+ for (int i=0; i<n; i+=1) {
+ sum += data[i] * data[i];
+ }
+ return sum;
+}
+
+/* { dg-final { scan-assembler-times {\tudot\tz[0-9]+\.s, z[0-9]+\.h, z[0-9]+\.h\n} 5 } } */
+/* { dg-final { scan-assembler-times {\tsdot\tz[0-9]+\.s, z[0-9]+\.h, z[0-9]+\.h\n} 5 } } */
+/* { dg-final { scan-assembler-times {\twhilelo\t} 4 } } */