Message ID | 20230714062413.2277485-1-haochen.jiang@intel.com |
---|---|
State | New |
Headers | show |
Series | i386: Auto vectorize usdot_prod, udot_prod with AVXVNNIINT16 instruction. | expand |
On Fri, Jul 14, 2023 at 8:24 AM Haochen Jiang <haochen.jiang@intel.com> wrote: > > Hi all, > > This patch aims to auto vectorize usdot_prod and udot_prod with newly > introduced AVX-VNNI-INT16. > > Also I refined the redundant mode iterator in the patch. > > Regtested on x86_64-pc-linux-gnu. Ok for trunk after AVX-VNNI-INT16 patch > checked in? > > BRs, > Haochen > > gcc/ChangeLog: > > * config/i386/sse.md (VI2_AVX2): Delete V32HI since we actually > have the same iterator. Also renaming all the occurrences to > VI2_AVX2_AVX512BW. > (usdot_prod<mode>): New define_expand. > (udot_prod<mode>): Ditto. > > gcc/testsuite/ChangeLog: > > * gcc.target/i386/vnniint16-auto-vectorize-1.c: New test. > * gcc.target/i386/vnniint16-auto-vectorize-2.c: Ditto. OK with two changes below. Thanks, Uros. > --- > gcc/config/i386/sse.md | 98 +++++++++++++------ > .../i386/vnniint16-auto-vectorize-1.c | 28 ++++++ > .../i386/vnniint16-auto-vectorize-2.c | 76 ++++++++++++++ > 3 files changed, 172 insertions(+), 30 deletions(-) > create mode 100644 gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c > create mode 100644 gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c > > diff --git a/gcc/config/i386/sse.md b/gcc/config/i386/sse.md > index 7471932b27e..98e7f9334bc 100644 > --- a/gcc/config/i386/sse.md > +++ b/gcc/config/i386/sse.md > @@ -545,6 +545,9 @@ > V32HI (V16HI "TARGET_AVX512VL")]) > > (define_mode_iterator VI2_AVX2 > + [(V16HI "TARGET_AVX2") V8HI]) > + > +(define_mode_iterator VI2_AVX2_AVX512BW > [(V32HI "TARGET_AVX512BW") (V16HI "TARGET_AVX2") V8HI]) > > (define_mode_iterator VI2_AVX512F > @@ -637,9 +640,6 @@ > (V16HI "TARGET_AVX2") V8HI > (V8SI "TARGET_AVX2") V4SI]) > > -(define_mode_iterator VI2_AVX2_AVX512BW > - [(V32HI "TARGET_AVX512BW") (V16HI "TARGET_AVX2") V8HI]) > - > (define_mode_iterator VI248_AVX512VL > [V32HI V16SI V8DI > (V16HI "TARGET_AVX512VL") (V8SI "TARGET_AVX512VL") > @@ -15298,16 +15298,16 @@ > }) > > (define_expand "mul<mode>3<mask_name>" > - [(set 
(match_operand:VI2_AVX2 0 "register_operand") > - (mult:VI2_AVX2 (match_operand:VI2_AVX2 1 "vector_operand") > - (match_operand:VI2_AVX2 2 "vector_operand")))] > + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand") > + (mult:VI2_AVX2_AVX512BW (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand") > + (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand")))] > "TARGET_SSE2 && <mask_mode512bit_condition> && <mask_avx512bw_condition>" > "ix86_fixup_binary_operands_no_copy (MULT, <MODE>mode, operands);") > > (define_insn "*mul<mode>3<mask_name>" > - [(set (match_operand:VI2_AVX2 0 "register_operand" "=x,<v_Yw>") > - (mult:VI2_AVX2 (match_operand:VI2_AVX2 1 "vector_operand" "%0,<v_Yw>") > - (match_operand:VI2_AVX2 2 "vector_operand" "xBm,<v_Yw>m")))] > + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=x,<v_Yw>") > + (mult:VI2_AVX2_AVX512BW (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand" "%0,<v_Yw>") > + (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand" "xBm,<v_Yw>m")))] > "TARGET_SSE2 && !(MEM_P (operands[1]) && MEM_P (operands[2])) > && <mask_mode512bit_condition> && <mask_avx512bw_condition>" > "@ > @@ -15320,28 +15320,28 @@ > (set_attr "mode" "<sseinsnmode>")]) > > (define_expand "<s>mul<mode>3_highpart<mask_name>" > - [(set (match_operand:VI2_AVX2 0 "register_operand") > - (truncate:VI2_AVX2 > + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand") > + (truncate:VI2_AVX2_AVX512BW > (lshiftrt:<ssedoublemode> > (mult:<ssedoublemode> > (any_extend:<ssedoublemode> > - (match_operand:VI2_AVX2 1 "vector_operand")) > + (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand")) > (any_extend:<ssedoublemode> > - (match_operand:VI2_AVX2 2 "vector_operand"))) > + (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand"))) > (const_int 16))))] > "TARGET_SSE2 > && <mask_mode512bit_condition> && <mask_avx512bw_condition>" > "ix86_fixup_binary_operands_no_copy (MULT, <MODE>mode, operands);") > > (define_insn "*<s>mul<mode>3_highpart<mask_name>" > - [(set 
(match_operand:VI2_AVX2 0 "register_operand" "=x,<v_Yw>") > - (truncate:VI2_AVX2 > + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=x,<v_Yw>") > + (truncate:VI2_AVX2_AVX512BW > (lshiftrt:<ssedoublemode> > (mult:<ssedoublemode> > (any_extend:<ssedoublemode> > - (match_operand:VI2_AVX2 1 "vector_operand" "%0,<v_Yw>")) > + (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand" "%0,<v_Yw>")) > (any_extend:<ssedoublemode> > - (match_operand:VI2_AVX2 2 "vector_operand" "xBm,<v_Yw>m"))) > + (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand" "xBm,<v_Yw>m"))) > (const_int 16))))] > "TARGET_SSE2 && !(MEM_P (operands[1]) && MEM_P (operands[2])) > && <mask_mode512bit_condition> && <mask_avx512bw_condition>" > @@ -15591,8 +15591,8 @@ > (define_insn "avx512bw_pmaddwd512<mode><mask_name>" > [(set (match_operand:<sseunpackmode> 0 "register_operand" "=v") > (unspec:<sseunpackmode> > - [(match_operand:VI2_AVX2 1 "register_operand" "v") > - (match_operand:VI2_AVX2 2 "nonimmediate_operand" "vm")] > + [(match_operand:VI2_AVX2_AVX512BW 1 "register_operand" "v") > + (match_operand:VI2_AVX2_AVX512BW 2 "nonimmediate_operand" "vm")] > UNSPEC_PMADDWD512))] > "TARGET_AVX512BW && <mask_mode512bit_condition>" > "vpmaddwd\t{%2, %1, %0<mask_operand3>|%0<mask_operand3>, %1, %2}"; > @@ -21569,16 +21569,16 @@ > }) > > (define_expand "smulhrs<mode>3" > - [(set (match_operand:VI2_AVX2 0 "register_operand") > - (truncate:VI2_AVX2 > + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand") > + (truncate:VI2_AVX2_AVX512BW > (lshiftrt:<ssedoublemode> > (plus:<ssedoublemode> > (lshiftrt:<ssedoublemode> > (mult:<ssedoublemode> > (sign_extend:<ssedoublemode> > - (match_operand:VI2_AVX2 1 "nonimmediate_operand")) > + (match_operand:VI2_AVX2_AVX512BW 1 "nonimmediate_operand")) > (sign_extend:<ssedoublemode> > - (match_operand:VI2_AVX2 2 "nonimmediate_operand"))) > + (match_operand:VI2_AVX2_AVX512BW 2 "nonimmediate_operand"))) > (const_int 14)) > (match_dup 3)) > (const_int 1))))] > @@ 
-21589,18 +21589,18 @@ > }) > > (define_insn "*<ssse3_avx2>_pmulhrsw<mode>3<mask_name>" > - [(set (match_operand:VI2_AVX2 0 "register_operand" "=x,<v_Yw>") > - (truncate:VI2_AVX2 > + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=x,<v_Yw>") > + (truncate:VI2_AVX2_AVX512BW > (lshiftrt:<ssedoublemode> > (plus:<ssedoublemode> > (lshiftrt:<ssedoublemode> > (mult:<ssedoublemode> > (sign_extend:<ssedoublemode> > - (match_operand:VI2_AVX2 1 "vector_operand" "%0,<v_Yw>")) > + (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand" "%0,<v_Yw>")) > (sign_extend:<ssedoublemode> > - (match_operand:VI2_AVX2 2 "vector_operand" "xBm,<v_Yw>m"))) > + (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand" "xBm,<v_Yw>m"))) > (const_int 14)) > - (match_operand:VI2_AVX2 3 "const1_operand")) > + (match_operand:VI2_AVX2_AVX512BW 3 "const1_operand")) > (const_int 1))))] > "TARGET_SSSE3 && <mask_mode512bit_condition> && <mask_avx512bw_condition> > && !(MEM_P (operands[1]) && MEM_P (operands[2]))" > @@ -22327,8 +22327,8 @@ > (set_attr "mode" "<sseinsnmode>")]) > > (define_insn "<sse4_1_avx2>_packusdw<mask_name>" > - [(set (match_operand:VI2_AVX2 0 "register_operand" "=Yr,*x,<v_Yw>") > - (unspec:VI2_AVX2 > + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=Yr,*x,<v_Yw>") > + (unspec:VI2_AVX2_AVX512BW > [(match_operand:<sseunpackmode> 1 "register_operand" "0,0,<v_Yw>") > (match_operand:<sseunpackmode> 2 "vector_operand" "YrBm,*xBm,<v_Yw>m")] > UNSPEC_US_TRUNCATE))] > @@ -30340,6 +30340,44 @@ > (UNSPEC_VPDPWSUD "wsud") (UNSPEC_VPDPWSUDS "wsuds") > (UNSPEC_VPDPWUUD "wuud") (UNSPEC_VPDPWUUDS "wuuds")]) > > +(define_expand "usdot_prod<mode>" > + [(match_operand:<sseunpackmode> 0 "register_operand") > + (match_operand:VI2_AVX2 1 "register_operand") > + (match_operand:VI2_AVX2 2 "register_operand") > + (match_operand:<sseunpackmode> 3 "register_operand")] > + "TARGET_AVXVNNIINT16" > +{ > + operands[1] = lowpart_subreg (<sseunpackmode>mode, > + force_reg (<MODE>mode, 
operands[1]), > + <MODE>mode); > + operands[2] = lowpart_subreg (<sseunpackmode>mode, > + force_reg (<MODE>mode, operands[2]), > + <MODE>mode); > + emit_insn (gen_rtx_SET (operands[0], operands[3])); You don't have to emit a move, the register allocator will do that for you. > + emit_insn (gen_vpdpwusd_<SDOT_VPDP_SUF> (operands[0], operands[3], > + operands[1], operands[2])); > + DONE; > +}) > + > +(define_expand "udot_prod<mode>" > + [(match_operand:<sseunpackmode> 0 "register_operand") > + (match_operand:VI2_AVX2 1 "register_operand") > + (match_operand:VI2_AVX2 2 "register_operand") > + (match_operand:<sseunpackmode> 3 "register_operand")] > + "TARGET_AVXVNNIINT16" > +{ > + operands[1] = lowpart_subreg (<sseunpackmode>mode, > + force_reg (<MODE>mode, operands[1]), > + <MODE>mode); > + operands[2] = lowpart_subreg (<sseunpackmode>mode, > + force_reg (<MODE>mode, operands[2]), > + <MODE>mode); > + emit_insn (gen_rtx_SET (operands[0], operands[3])); Also here, the above is not needed. > + emit_insn (gen_vpdpwuud_<SDOT_VPDP_SUF> (operands[0], operands[3], > + operands[1], operands[2])); > + DONE; > +}) > + > (define_insn "vpdp<vpdpwprodtype>_<mode>" > [(set (match_operand:VI4_AVX 0 "register_operand" "=x") > (unspec:VI4_AVX > diff --git a/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c > new file mode 100644 > index 00000000000..73f0d3296aa > --- /dev/null > +++ b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c > @@ -0,0 +1,28 @@ > +/* { dg-do compile } */ > +/* { dg-options "-mavxvnniint16 -O2" } */ > +/* { dg-final { scan-assembler "vpdpwusd\t" } } */ > +/* { dg-final { scan-assembler "vpdpwuud\t" } } */ > + > +int __attribute__((noinline, noclone, optimize("tree-vectorize"))) > +usdot_prod_hi (unsigned short * restrict a, short * restrict b, > + int c, int n) > +{ > + int i; > + for (i = 0; i < n; i++) > + { > + c += ((int) a[i] * (int) b[i]); > + } > + return c; > +} > + > +int 
__attribute__((noinline, noclone, optimize("tree-vectorize"))) > +udot_prod_hi (unsigned short * restrict a, unsigned short *restrict b, > + int c, int n) > +{ > + int i; > + for (i = 0; i < n; i++) > + { > + c += ((int) a[i] * (int) b[i]); > + } > + return c; > +} > diff --git a/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c > new file mode 100644 > index 00000000000..90dc0eade7e > --- /dev/null > +++ b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c > @@ -0,0 +1,76 @@ > +/* { dg-do run } */ > +/* { dg-options "-O2 -mavxvnniint16" } */ > +/* { dg-require-effective-target avxvnniint16 } */ > + > +#define AVXVNNIINT16 > +#ifndef CHECK > +#define CHECK "avx-check.h" > +#endif > + > +#ifndef TEST > +#define TEST avx_test > +#endif > + > +#include CHECK > +#include "vnniint16-auto-vectorize-1.c" > + > +#define N 256 > + > +short a_i16[N]; > +unsigned short b_u16[N], c_u16[N], d_u16[N]; > +int i16_exp, i16_ref; > + > +int __attribute__((noinline, noclone, optimize("no-tree-vectorize"))) > +udot_prod_hi_scalar (unsigned short * restrict a, unsigned short * restrict b, > + int c, int n) > +{ > + int i; > + for (i = 0; i < n; i++) > + { > + c += ((int) a[i] * (int) b[i]); > + } > + return c; > +} > + > +int __attribute__((noinline, noclone, optimize("no-tree-vectorize"))) > +usdot_prod_hi_scalar (unsigned short * restrict a, short *restrict b, > + int c, int n) > +{ > + int i; > + for (i = 0; i < n; i++) > + { > + c += ((int) a[i] * (int) b[i]); > + } > + return c; > +} > + > +void init () > +{ > + int i; > + > + i16_exp = i16_ref = 65535; > + > + for (i = 0; i < N; i++) > + { > + a_i16[i] = -i + 2; > + b_u16[i] = i * 2; > + c_u16[i] = i * 3; > + d_u16[i] = i * 4; > + } > +} > + > +void > +TEST (void) > +{ > + init (); > + i16_exp = usdot_prod_hi (a_i16, b_u16, i16_exp, N); > + i16_ref = usdot_prod_hi_scalar (a_i16, b_u16, i16_ref, N); > + if (i16_exp != i16_ref) > + abort (); > + > + init 
(); > + i16_exp = udot_prod_hi (c_u16, d_u16, i16_exp, N); > + i16_ref = udot_prod_hi_scalar (c_u16, d_u16, i16_ref, N); > + if (i16_exp != i16_ref) > + abort (); > +} > -- > 2.31.1 >
diff --git a/gcc/config/i386/sse.md b/gcc/config/i386/sse.md index 7471932b27e..98e7f9334bc 100644 --- a/gcc/config/i386/sse.md +++ b/gcc/config/i386/sse.md @@ -545,6 +545,9 @@ V32HI (V16HI "TARGET_AVX512VL")]) (define_mode_iterator VI2_AVX2 + [(V16HI "TARGET_AVX2") V8HI]) + +(define_mode_iterator VI2_AVX2_AVX512BW [(V32HI "TARGET_AVX512BW") (V16HI "TARGET_AVX2") V8HI]) (define_mode_iterator VI2_AVX512F @@ -637,9 +640,6 @@ (V16HI "TARGET_AVX2") V8HI (V8SI "TARGET_AVX2") V4SI]) -(define_mode_iterator VI2_AVX2_AVX512BW - [(V32HI "TARGET_AVX512BW") (V16HI "TARGET_AVX2") V8HI]) - (define_mode_iterator VI248_AVX512VL [V32HI V16SI V8DI (V16HI "TARGET_AVX512VL") (V8SI "TARGET_AVX512VL") @@ -15298,16 +15298,16 @@ }) (define_expand "mul<mode>3<mask_name>" - [(set (match_operand:VI2_AVX2 0 "register_operand") - (mult:VI2_AVX2 (match_operand:VI2_AVX2 1 "vector_operand") - (match_operand:VI2_AVX2 2 "vector_operand")))] + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand") + (mult:VI2_AVX2_AVX512BW (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand") + (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand")))] "TARGET_SSE2 && <mask_mode512bit_condition> && <mask_avx512bw_condition>" "ix86_fixup_binary_operands_no_copy (MULT, <MODE>mode, operands);") (define_insn "*mul<mode>3<mask_name>" - [(set (match_operand:VI2_AVX2 0 "register_operand" "=x,<v_Yw>") - (mult:VI2_AVX2 (match_operand:VI2_AVX2 1 "vector_operand" "%0,<v_Yw>") - (match_operand:VI2_AVX2 2 "vector_operand" "xBm,<v_Yw>m")))] + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=x,<v_Yw>") + (mult:VI2_AVX2_AVX512BW (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand" "%0,<v_Yw>") + (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand" "xBm,<v_Yw>m")))] "TARGET_SSE2 && !(MEM_P (operands[1]) && MEM_P (operands[2])) && <mask_mode512bit_condition> && <mask_avx512bw_condition>" "@ @@ -15320,28 +15320,28 @@ (set_attr "mode" "<sseinsnmode>")]) (define_expand "<s>mul<mode>3_highpart<mask_name>" - [(set 
(match_operand:VI2_AVX2 0 "register_operand") - (truncate:VI2_AVX2 + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand") + (truncate:VI2_AVX2_AVX512BW (lshiftrt:<ssedoublemode> (mult:<ssedoublemode> (any_extend:<ssedoublemode> - (match_operand:VI2_AVX2 1 "vector_operand")) + (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand")) (any_extend:<ssedoublemode> - (match_operand:VI2_AVX2 2 "vector_operand"))) + (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand"))) (const_int 16))))] "TARGET_SSE2 && <mask_mode512bit_condition> && <mask_avx512bw_condition>" "ix86_fixup_binary_operands_no_copy (MULT, <MODE>mode, operands);") (define_insn "*<s>mul<mode>3_highpart<mask_name>" - [(set (match_operand:VI2_AVX2 0 "register_operand" "=x,<v_Yw>") - (truncate:VI2_AVX2 + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=x,<v_Yw>") + (truncate:VI2_AVX2_AVX512BW (lshiftrt:<ssedoublemode> (mult:<ssedoublemode> (any_extend:<ssedoublemode> - (match_operand:VI2_AVX2 1 "vector_operand" "%0,<v_Yw>")) + (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand" "%0,<v_Yw>")) (any_extend:<ssedoublemode> - (match_operand:VI2_AVX2 2 "vector_operand" "xBm,<v_Yw>m"))) + (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand" "xBm,<v_Yw>m"))) (const_int 16))))] "TARGET_SSE2 && !(MEM_P (operands[1]) && MEM_P (operands[2])) && <mask_mode512bit_condition> && <mask_avx512bw_condition>" @@ -15591,8 +15591,8 @@ (define_insn "avx512bw_pmaddwd512<mode><mask_name>" [(set (match_operand:<sseunpackmode> 0 "register_operand" "=v") (unspec:<sseunpackmode> - [(match_operand:VI2_AVX2 1 "register_operand" "v") - (match_operand:VI2_AVX2 2 "nonimmediate_operand" "vm")] + [(match_operand:VI2_AVX2_AVX512BW 1 "register_operand" "v") + (match_operand:VI2_AVX2_AVX512BW 2 "nonimmediate_operand" "vm")] UNSPEC_PMADDWD512))] "TARGET_AVX512BW && <mask_mode512bit_condition>" "vpmaddwd\t{%2, %1, %0<mask_operand3>|%0<mask_operand3>, %1, %2}"; @@ -21569,16 +21569,16 @@ }) (define_expand "smulhrs<mode>3" - [(set 
(match_operand:VI2_AVX2 0 "register_operand") - (truncate:VI2_AVX2 + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand") + (truncate:VI2_AVX2_AVX512BW (lshiftrt:<ssedoublemode> (plus:<ssedoublemode> (lshiftrt:<ssedoublemode> (mult:<ssedoublemode> (sign_extend:<ssedoublemode> - (match_operand:VI2_AVX2 1 "nonimmediate_operand")) + (match_operand:VI2_AVX2_AVX512BW 1 "nonimmediate_operand")) (sign_extend:<ssedoublemode> - (match_operand:VI2_AVX2 2 "nonimmediate_operand"))) + (match_operand:VI2_AVX2_AVX512BW 2 "nonimmediate_operand"))) (const_int 14)) (match_dup 3)) (const_int 1))))] @@ -21589,18 +21589,18 @@ }) (define_insn "*<ssse3_avx2>_pmulhrsw<mode>3<mask_name>" - [(set (match_operand:VI2_AVX2 0 "register_operand" "=x,<v_Yw>") - (truncate:VI2_AVX2 + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=x,<v_Yw>") + (truncate:VI2_AVX2_AVX512BW (lshiftrt:<ssedoublemode> (plus:<ssedoublemode> (lshiftrt:<ssedoublemode> (mult:<ssedoublemode> (sign_extend:<ssedoublemode> - (match_operand:VI2_AVX2 1 "vector_operand" "%0,<v_Yw>")) + (match_operand:VI2_AVX2_AVX512BW 1 "vector_operand" "%0,<v_Yw>")) (sign_extend:<ssedoublemode> - (match_operand:VI2_AVX2 2 "vector_operand" "xBm,<v_Yw>m"))) + (match_operand:VI2_AVX2_AVX512BW 2 "vector_operand" "xBm,<v_Yw>m"))) (const_int 14)) - (match_operand:VI2_AVX2 3 "const1_operand")) + (match_operand:VI2_AVX2_AVX512BW 3 "const1_operand")) (const_int 1))))] "TARGET_SSSE3 && <mask_mode512bit_condition> && <mask_avx512bw_condition> && !(MEM_P (operands[1]) && MEM_P (operands[2]))" @@ -22327,8 +22327,8 @@ (set_attr "mode" "<sseinsnmode>")]) (define_insn "<sse4_1_avx2>_packusdw<mask_name>" - [(set (match_operand:VI2_AVX2 0 "register_operand" "=Yr,*x,<v_Yw>") - (unspec:VI2_AVX2 + [(set (match_operand:VI2_AVX2_AVX512BW 0 "register_operand" "=Yr,*x,<v_Yw>") + (unspec:VI2_AVX2_AVX512BW [(match_operand:<sseunpackmode> 1 "register_operand" "0,0,<v_Yw>") (match_operand:<sseunpackmode> 2 "vector_operand" "YrBm,*xBm,<v_Yw>m")] 
UNSPEC_US_TRUNCATE))] @@ -30340,6 +30340,44 @@ (UNSPEC_VPDPWSUD "wsud") (UNSPEC_VPDPWSUDS "wsuds") (UNSPEC_VPDPWUUD "wuud") (UNSPEC_VPDPWUUDS "wuuds")]) +(define_expand "usdot_prod<mode>" + [(match_operand:<sseunpackmode> 0 "register_operand") + (match_operand:VI2_AVX2 1 "register_operand") + (match_operand:VI2_AVX2 2 "register_operand") + (match_operand:<sseunpackmode> 3 "register_operand")] + "TARGET_AVXVNNIINT16" +{ + operands[1] = lowpart_subreg (<sseunpackmode>mode, + force_reg (<MODE>mode, operands[1]), + <MODE>mode); + operands[2] = lowpart_subreg (<sseunpackmode>mode, + force_reg (<MODE>mode, operands[2]), + <MODE>mode); + emit_insn (gen_rtx_SET (operands[0], operands[3])); + emit_insn (gen_vpdpwusd_<SDOT_VPDP_SUF> (operands[0], operands[3], + operands[1], operands[2])); + DONE; +}) + +(define_expand "udot_prod<mode>" + [(match_operand:<sseunpackmode> 0 "register_operand") + (match_operand:VI2_AVX2 1 "register_operand") + (match_operand:VI2_AVX2 2 "register_operand") + (match_operand:<sseunpackmode> 3 "register_operand")] + "TARGET_AVXVNNIINT16" +{ + operands[1] = lowpart_subreg (<sseunpackmode>mode, + force_reg (<MODE>mode, operands[1]), + <MODE>mode); + operands[2] = lowpart_subreg (<sseunpackmode>mode, + force_reg (<MODE>mode, operands[2]), + <MODE>mode); + emit_insn (gen_rtx_SET (operands[0], operands[3])); + emit_insn (gen_vpdpwuud_<SDOT_VPDP_SUF> (operands[0], operands[3], + operands[1], operands[2])); + DONE; +}) + (define_insn "vpdp<vpdpwprodtype>_<mode>" [(set (match_operand:VI4_AVX 0 "register_operand" "=x") (unspec:VI4_AVX diff --git a/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c new file mode 100644 index 00000000000..73f0d3296aa --- /dev/null +++ b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-1.c @@ -0,0 +1,28 @@ +/* { dg-do compile } */ +/* { dg-options "-mavxvnniint16 -O2" } */ +/* { dg-final { scan-assembler "vpdpwusd\t" } } */ +/* { dg-final { 
scan-assembler "vpdpwuud\t" } } */ + +int __attribute__((noinline, noclone, optimize("tree-vectorize"))) +usdot_prod_hi (unsigned short * restrict a, short * restrict b, + int c, int n) +{ + int i; + for (i = 0; i < n; i++) + { + c += ((int) a[i] * (int) b[i]); + } + return c; +} + +int __attribute__((noinline, noclone, optimize("tree-vectorize"))) +udot_prod_hi (unsigned short * restrict a, unsigned short *restrict b, + int c, int n) +{ + int i; + for (i = 0; i < n; i++) + { + c += ((int) a[i] * (int) b[i]); + } + return c; +} diff --git a/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c new file mode 100644 index 00000000000..90dc0eade7e --- /dev/null +++ b/gcc/testsuite/gcc.target/i386/vnniint16-auto-vectorize-2.c @@ -0,0 +1,76 @@ +/* { dg-do run } */ +/* { dg-options "-O2 -mavxvnniint16" } */ +/* { dg-require-effective-target avxvnniint16 } */ + +#define AVXVNNIINT16 +#ifndef CHECK +#define CHECK "avx-check.h" +#endif + +#ifndef TEST +#define TEST avx_test +#endif + +#include CHECK +#include "vnniint16-auto-vectorize-1.c" + +#define N 256 + +short a_i16[N]; +unsigned short b_u16[N], c_u16[N], d_u16[N]; +int i16_exp, i16_ref; + +int __attribute__((noinline, noclone, optimize("no-tree-vectorize"))) +udot_prod_hi_scalar (unsigned short * restrict a, unsigned short * restrict b, + int c, int n) +{ + int i; + for (i = 0; i < n; i++) + { + c += ((int) a[i] * (int) b[i]); + } + return c; +} + +int __attribute__((noinline, noclone, optimize("no-tree-vectorize"))) +usdot_prod_hi_scalar (unsigned short * restrict a, short *restrict b, + int c, int n) +{ + int i; + for (i = 0; i < n; i++) + { + c += ((int) a[i] * (int) b[i]); + } + return c; +} + +void init () +{ + int i; + + i16_exp = i16_ref = 65535; + + for (i = 0; i < N; i++) + { + a_i16[i] = -i + 2; + b_u16[i] = i * 2; + c_u16[i] = i * 3; + d_u16[i] = i * 4; + } +} + +void +TEST (void) +{ + init (); + i16_exp = usdot_prod_hi (a_i16, b_u16, 
i16_exp, N); + i16_ref = usdot_prod_hi_scalar (a_i16, b_u16, i16_ref, N); + if (i16_exp != i16_ref) + abort (); + + init (); + i16_exp = udot_prod_hi (c_u16, d_u16, i16_exp, N); + i16_ref = udot_prod_hi_scalar (c_u16, d_u16, i16_ref, N); + if (i16_exp != i16_ref) + abort (); +}