Message ID: 20240904033121.1895231-1-admin@levyhsu.com
State:      New
Series:     i386: Support partial vectorized FMA for V2BF/V4BF
On Wed, Sep 4, 2024 at 11:31 AM Levy Hsu <admin@levyhsu.com> wrote:
>
> Hi
>
> Bootstrapped and tested on x86-64-pc-linux-gnu.
> Ok for trunk?

Ok.

> This patch introduces support for vectorized FMA operations for bf16 types
> in V2BF and V4BF modes on the i386 architecture.  A new mode iterator and
> define_expand entries for the fma, fnma, fms, and fnms operations are added
> in mmx.md, enabling the i386 backend to handle these operations.
>
> gcc/ChangeLog:
>
>         * config/i386/mmx.md (VBF_32_64): New mode iterator.
>         (fma<mode>4): New define_expand for V2BF/V4BF fma.
>         (fnma<mode>4): New define_expand for V2BF/V4BF fnma.
>         (fms<mode>4): New define_expand for V2BF/V4BF fms.
>         (fnms<mode>4): New define_expand for V2BF/V4BF fnms.
>
> gcc/testsuite/ChangeLog:
>
>         * gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c: New test.
> ---
>  gcc/config/i386/mmx.md                             | 84 +++++++++++++++++-
>  .../i386/avx10_2-partial-bf-vector-fma-1.c         | 57 ++++++++++++++
>  2 files changed, 139 insertions(+), 2 deletions(-)
>  create mode 100644 gcc/testsuite/gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c
> [...]
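For reference, the four optab names wired up here differ only in which inputs are negated. Below is a minimal C sketch (illustration only, not code from the patch) of the per-lane semantics on a 64-bit V4BF vector; the hardware instructions fuse the multiply and add, while this float emulation rounds the sum in float and again on conversion to bf16, so the last bit can differ.

typedef __bf16 v4bf __attribute__ ((__vector_size__ (8)));

/* Reference semantics, lane by lane through float.  The product of two
   bf16 values is exact in float (8-bit significands), but the float sum
   plus the conversion back to bf16 can double-round relative to a fused
   operation.  */
static v4bf
ref_op (v4bf a, v4bf b, v4bf c, int kind)
{
  v4bf r;
  for (int i = 0; i < 4; i++)
    {
      float p = (float) a[i] * (float) b[i];
      float x = (float) c[i];
      float s;
      switch (kind)
        {
        case 0:  s =  p + x; break;  /* fma<mode>4:   a * b + c */
        case 1:  s =  p - x; break;  /* fms<mode>4:   a * b - c */
        case 2:  s = -p + x; break;  /* fnma<mode>4: -a * b + c */
        default: s = -p - x; break;  /* fnms<mode>4: -a * b - c */
        }
      r[i] = (__bf16) s;
    }
  return r;
}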
diff --git a/gcc/config/i386/mmx.md b/gcc/config/i386/mmx.md
index 10fcd2beda6..22aeb43f436 100644
--- a/gcc/config/i386/mmx.md
+++ b/gcc/config/i386/mmx.md
@@ -2636,6 +2636,88 @@
   DONE;
 })
 
+(define_mode_iterator VBF_32_64 [V2BF (V4BF "TARGET_MMX_WITH_SSE")])
+
+(define_expand "fma<mode>4"
+  [(set (match_operand:VBF_32_64 0 "register_operand")
+        (fma:VBF_32_64
+          (match_operand:VBF_32_64 1 "nonimmediate_operand")
+          (match_operand:VBF_32_64 2 "nonimmediate_operand")
+          (match_operand:VBF_32_64 3 "nonimmediate_operand")))]
+  "TARGET_AVX10_2_256"
+{
+  rtx op0 = gen_reg_rtx (V8BFmode);
+  rtx op1 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[1]), <MODE>mode);
+  rtx op2 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[2]), <MODE>mode);
+  rtx op3 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[3]), <MODE>mode);
+
+  emit_insn (gen_fmav8bf4 (op0, op1, op2, op3));
+
+  emit_move_insn (operands[0], lowpart_subreg (<MODE>mode, op0, V8BFmode));
+  DONE;
+})
+
+(define_expand "fms<mode>4"
+  [(set (match_operand:VBF_32_64 0 "register_operand")
+        (fma:VBF_32_64
+          (match_operand:VBF_32_64 1 "nonimmediate_operand")
+          (match_operand:VBF_32_64 2 "nonimmediate_operand")
+          (neg:VBF_32_64
+            (match_operand:VBF_32_64 3 "nonimmediate_operand"))))]
+  "TARGET_AVX10_2_256"
+{
+  rtx op0 = gen_reg_rtx (V8BFmode);
+  rtx op1 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[1]), <MODE>mode);
+  rtx op2 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[2]), <MODE>mode);
+  rtx op3 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[3]), <MODE>mode);
+
+  emit_insn (gen_fmsv8bf4 (op0, op1, op2, op3));
+
+  emit_move_insn (operands[0], lowpart_subreg (<MODE>mode, op0, V8BFmode));
+  DONE;
+})
+
+(define_expand "fnma<mode>4"
+  [(set (match_operand:VBF_32_64 0 "register_operand")
+        (fma:VBF_32_64
+          (neg:VBF_32_64
+            (match_operand:VBF_32_64 1 "nonimmediate_operand"))
+          (match_operand:VBF_32_64 2 "nonimmediate_operand")
+          (match_operand:VBF_32_64 3 "nonimmediate_operand")))]
+  "TARGET_AVX10_2_256"
+{
+  rtx op0 = gen_reg_rtx (V8BFmode);
+  rtx op1 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[1]), <MODE>mode);
+  rtx op2 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[2]), <MODE>mode);
+  rtx op3 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[3]), <MODE>mode);
+
+  emit_insn (gen_fnmav8bf4 (op0, op1, op2, op3));
+
+  emit_move_insn (operands[0], lowpart_subreg (<MODE>mode, op0, V8BFmode));
+  DONE;
+})
+
+(define_expand "fnms<mode>4"
+  [(set (match_operand:VBF_32_64 0 "register_operand")
+        (fma:VBF_32_64
+          (neg:VBF_32_64
+            (match_operand:VBF_32_64 1 "nonimmediate_operand"))
+          (match_operand:VBF_32_64 2 "nonimmediate_operand")
+          (neg:VBF_32_64
+            (match_operand:VBF_32_64 3 "nonimmediate_operand"))))]
+  "TARGET_AVX10_2_256"
+{
+  rtx op0 = gen_reg_rtx (V8BFmode);
+  rtx op1 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[1]), <MODE>mode);
+  rtx op2 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[2]), <MODE>mode);
+  rtx op3 = lowpart_subreg (V8BFmode, force_reg (<MODE>mode, operands[3]), <MODE>mode);
+
+  emit_insn (gen_fnmsv8bf4 (op0, op1, op2, op3));
+
+  emit_move_insn (operands[0], lowpart_subreg (<MODE>mode, op0, V8BFmode));
+  DONE;
+})
+
 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 ;;
 ;; Parallel half-precision floating point complex type operations
@@ -6670,8 +6752,6 @@
    (set_attr "modrm" "0")
    (set_attr "memory" "none")])
 
-(define_mode_iterator VBF_32_64 [V2BF (V4BF "TARGET_MMX_WITH_SSE")])
-
 ;; VDIVNEPBF16 does not generate floating point exceptions.
 (define_expand "<insn><mode>3"
   [(set (match_operand:VBF_32_64 0 "register_operand")
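Each expander above has the same shape: the 32- or 64-bit operands are forced into registers and viewed as the low part of a 128-bit V8BF vector via lowpart_subreg, the existing full-width V8BF pattern performs the operation, and the low part of the result is moved back into the narrow destination. In plain C the idea looks roughly like the sketch below (hypothetical helper names, not code from the patch; assumes a compiler and target where V8BF arithmetic is available, e.g. -mavx10.2):

#include <string.h>

typedef __bf16 v4bf __attribute__ ((__vector_size__ (8)));
typedef __bf16 v8bf __attribute__ ((__vector_size__ (16)));

/* Hypothetical C analogy of the expander: widen into the low half of a
   128-bit vector, operate at full width, take the low half back.  */
static v8bf
widen_lowpart (v4bf x)
{
  v8bf wide = { 0 };             /* upper lanes are don't-care */
  memcpy (&wide, &x, sizeof (x));
  return wide;
}

static v4bf
fma_v4bf (v4bf a, v4bf b, v4bf c)
{
  /* Stands in for gen_fmav8bf4 on the subreg'd operands.  */
  v8bf wr = widen_lowpart (a) * widen_lowpart (b) + widen_lowpart (c);
  v4bf r;
  memcpy (&r, &wr, sizeof (r));  /* lowpart_subreg back to V4BF */
  return r;
}

Computing on the unused upper lanes is safe here because, as the retained comment in the second hunk notes for VDIVNEPBF16, the AVX10.2 bf16 instructions do not generate floating point exceptions.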
diff --git a/gcc/testsuite/gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c b/gcc/testsuite/gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c
new file mode 100644
index 00000000000..72e17e99603
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/avx10_2-partial-bf-vector-fma-1.c
@@ -0,0 +1,57 @@
+/* { dg-do compile } */
+/* { dg-options "-mavx10.2 -O2" } */
+/* { dg-final { scan-assembler-times "vfmadd132nepbf16\[ \\t\]+\[^\{\n\]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 2 } } */
+/* { dg-final { scan-assembler-times "vfmsub132nepbf16\[ \\t\]+\[^\{\n\]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 2 } } */
+/* { dg-final { scan-assembler-times "vfnmadd132nepbf16\[ \\t\]+\[^\{\n\]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 2 } } */
+/* { dg-final { scan-assembler-times "vfnmsub132nepbf16\[ \\t\]+\[^\{\n\]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+\[^\n\r]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 2 } } */
+
+typedef __bf16 v4bf __attribute__ ((__vector_size__ (8)));
+typedef __bf16 v2bf __attribute__ ((__vector_size__ (4)));
+
+v4bf
+foo_madd_64 (v4bf a, v4bf b, v4bf c)
+{
+  return a * b + c;
+}
+
+v4bf
+foo_msub_64 (v4bf a, v4bf b, v4bf c)
+{
+  return a * b - c;
+}
+
+v4bf
+foo_nmadd_64 (v4bf a, v4bf b, v4bf c)
+{
+  return -a * b + c;
+}
+
+v4bf
+foo_nmsub_64 (v4bf a, v4bf b, v4bf c)
+{
+  return -a * b - c;
+}
+
+v2bf
+foo_madd_32 (v2bf a, v2bf b, v2bf c)
+{
+  return a * b + c;
+}
+
+v2bf
+foo_msub_32 (v2bf a, v2bf b, v2bf c)
+{
+  return a * b - c;
+}
+
+v2bf
+foo_nmadd_32 (v2bf a, v2bf b, v2bf c)
+{
+  return -a * b + c;
+}
+
+v2bf
+foo_nmsub_32 (v2bf a, v2bf b, v2bf c)
+{
+  return -a * b - c;
+}
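The test is compile-only and merely counts instructions. As a hypothetical runtime spot-check (not part of the patch), one form can be compared against a float reference on AVX10.2 hardware or under an emulator such as Intel SDE; the values below are chosen to be exactly representable in bf16 so the comparison is exact.

#include <stdio.h>

typedef __bf16 v4bf __attribute__ ((__vector_size__ (8)));

v4bf
foo_madd_64 (v4bf a, v4bf b, v4bf c)
{
  return a * b + c;
}

int
main (void)
{
  v4bf a = { (__bf16) 1.5f, (__bf16) -2.0f, (__bf16) 0.25f, (__bf16) 8.0f };
  v4bf b = { (__bf16) 2.0f, (__bf16) 3.0f, (__bf16) -4.0f, (__bf16) 0.5f };
  v4bf c = { (__bf16) 1.0f, (__bf16) 1.0f, (__bf16) 1.0f, (__bf16) 1.0f };
  v4bf r = foo_madd_64 (a, b, c);

  for (int i = 0; i < 4; i++)
    {
      /* All inputs and results are exact in bf16, so got must equal want.  */
      float want = (float) a[i] * (float) b[i] + (float) c[i];
      float got = (float) r[i];
      printf ("lane %d: got %f want %f %s\n", i, got, want,
              got == want ? "ok" : "FAIL");
    }
  return 0;
}

Built with gcc -mavx10.2 -O2, foo_madd_64 should assemble to a single vfmadd132nepbf16 on %xmm registers, which is exactly what the scan-assembler-times patterns above count.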