diff mbox series

[3/3] aarch64: Add support for fp8fma instructions

Message ID 20241106100358.3622028-4-saurabh.jha@arm.com
State New
Headers show
Series aarch64: Add fp8, fp8dot2, fp8dot4, and fp8fma acle | expand

Commit Message

Saurabh Jha Nov. 6, 2024, 10:03 a.m. UTC
The AArch64 FEAT_FP8FMA extension introduces instructions for
multiply-add of vectors.

This patch introduces the following instructions:
1. {vmlalbq|vmlaltq}_f16_mf8_fpm.
2. {vmlalbq|vmlaltq}_lane{q}_f16_mf8_fpm.
3. {vmlallbbq|vmlallbtq|vmlalltbq|vmlallttq}_f32_mf8_fpm.
4. {vmlallbbq|vmlallbtq|vmlalltbq|vmlallttq}_lane{q}_f32_mf8_fpm.

It introduces the fp8fma flag.

gcc/ChangeLog:

	* config/aarch64/aarch64-builtins.cc
	(check_simd_lane_bounds): Add support for new unspecs.
	* config/aarch64/aarch64-c.cc
	(aarch64_update_cpp_builtins): New flags.
	* config/aarch64/aarch64-option-extensions.def
	(AARCH64_OPT_EXTENSION): New flags.
	* config/aarch64/aarch64-simd-pragma-builtins.def
	(ENTRY_FMA_FPM): Macro to declare fma intrinsics.
	* config/aarch64/aarch64-simd.md: New instruction pattern for
	fp8fma instructions.
	* config/aarch64/aarch64.h
	(TARGET_FP8FMA): New flag for fp8fma instructions.
	* config/aarch64/iterators.md: New attributes and iterators.
	* doc/invoke.texi: New flag for fp8fma instructions.

gcc/testsuite/ChangeLog:

	* gcc.target/aarch64/simd/fma_fpm.c: New test.

	---
	In the instruction pattern for handling lanes, I am not doing
	any casting of the third operand and hardcoding the '.b'
	suffix in the assembly string. Is that okay?
---
 gcc/config/aarch64/aarch64-builtins.cc        |  10 +
 gcc/config/aarch64/aarch64-c.cc               |   2 +
 .../aarch64/aarch64-option-extensions.def     |   2 +
 .../aarch64/aarch64-simd-pragma-builtins.def  |  19 ++
 gcc/config/aarch64/aarch64-simd.md            |  29 +++
 gcc/config/aarch64/aarch64.h                  |   3 +
 gcc/config/aarch64/iterators.md               |  18 ++
 gcc/doc/invoke.texi                           |   2 +
 .../gcc.target/aarch64/simd/fma_fpm.c         | 221 ++++++++++++++++++
 9 files changed, 306 insertions(+)
 create mode 100644 gcc/testsuite/gcc.target/aarch64/simd/fma_fpm.c
diff mbox series

Patch

diff --git a/gcc/config/aarch64/aarch64-builtins.cc b/gcc/config/aarch64/aarch64-builtins.cc
index ba3bffaa4f9..dc996f0563e 100644
--- a/gcc/config/aarch64/aarch64-builtins.cc
+++ b/gcc/config/aarch64/aarch64-builtins.cc
@@ -2593,6 +2593,16 @@  check_simd_lane_bounds (location_t location, const aarch64_pragma_builtins_data
 				     vector_to_index_mode_size / 4 - 1);
 	    break;
 
+	  case UNSPEC_FMLALB:
+	  case UNSPEC_FMLALT:
+	  case UNSPEC_FMLALLBB:
+	  case UNSPEC_FMLALLBT:
+	  case UNSPEC_FMLALLTB:
+	  case UNSPEC_FMLALLTT:
+	    require_immediate_range (location, index_arg, 0,
+				     vector_to_index_mode_size - 1);
+	    break;
+
 	  default:
 	    gcc_unreachable ();
 	  }
diff --git a/gcc/config/aarch64/aarch64-c.cc b/gcc/config/aarch64/aarch64-c.cc
index 3e30ba5afd9..4dc7711486f 100644
--- a/gcc/config/aarch64/aarch64-c.cc
+++ b/gcc/config/aarch64/aarch64-c.cc
@@ -263,6 +263,8 @@  aarch64_update_cpp_builtins (cpp_reader *pfile)
 
   aarch64_def_or_undef (TARGET_FP8DOT4, "__ARM_FEATURE_FP8DOT4", pfile);
 
+  aarch64_def_or_undef (TARGET_FP8FMA, "__ARM_FEATURE_FP8FMA", pfile);
+
   aarch64_def_or_undef (TARGET_LS64,
 			"__ARM_FEATURE_LS64", pfile);
   aarch64_def_or_undef (TARGET_RCPC, "__ARM_FEATURE_RCPC", pfile);
diff --git a/gcc/config/aarch64/aarch64-option-extensions.def b/gcc/config/aarch64/aarch64-option-extensions.def
index fd4d29e5df6..9806801e472 100644
--- a/gcc/config/aarch64/aarch64-option-extensions.def
+++ b/gcc/config/aarch64/aarch64-option-extensions.def
@@ -238,6 +238,8 @@  AARCH64_OPT_EXTENSION("fp8dot2", FP8DOT2, (SIMD), (), (), "fp8dot2")
 
 AARCH64_OPT_EXTENSION("fp8dot4", FP8DOT4, (SIMD), (), (), "fp8dot4")
 
+AARCH64_OPT_EXTENSION("fp8fma", FP8FMA, (SIMD), (), (), "fp8fma")
+
 AARCH64_OPT_EXTENSION("faminmax", FAMINMAX, (SIMD), (), (), "faminmax")
 
 #undef AARCH64_OPT_FMV_EXTENSION
diff --git a/gcc/config/aarch64/aarch64-simd-pragma-builtins.def b/gcc/config/aarch64/aarch64-simd-pragma-builtins.def
index 9dea2939b47..a85a4c48dbd 100644
--- a/gcc/config/aarch64/aarch64-simd-pragma-builtins.def
+++ b/gcc/config/aarch64/aarch64-simd-pragma-builtins.def
@@ -48,6 +48,15 @@ 
   ENTRY_TERNARY_FPM_LANE (vdotq_laneq_##T##_mf8_fpm, ternary_fpm_lane, T##q, \
 			  T##q, f8q, f8q, U)
 
+#undef ENTRY_FMA_FPM
+#define ENTRY_FMA_FPM(N, T, U)						\
+  ENTRY_TERNARY_FPM (N##_##T##_mf8_fpm, ternary_fpm,			\
+		     T##q, T##q, f8q, f8q, U)				\
+  ENTRY_TERNARY_FPM_LANE (N##_lane_##T##_mf8_fpm, ternary_fpm_lane,	\
+			  T##q, T##q, f8q, f8, U)			\
+  ENTRY_TERNARY_FPM_LANE (N##_laneq_##T##_mf8_fpm, ternary_fpm_lane,	\
+			  T##q, T##q, f8q, f8q, U)
+
 #undef ENTRY_UNARY_FPM
 #define ENTRY_UNARY_FPM(N, S, T0, T1, U) \
   ENTRY (N, S, T0, T1, none, none, none, U)
@@ -121,3 +130,13 @@  ENTRY_VDOT_FPM (f16, UNSPEC_VDOT2)
 #define REQUIRED_EXTENSIONS nonstreaming_only (AARCH64_FL_FP8DOT4)
 ENTRY_VDOT_FPM (f32, UNSPEC_VDOT4)
 #undef REQUIRED_EXTENSIONS
+
+// fp8 multiply-add
+#define REQUIRED_EXTENSIONS nonstreaming_only (AARCH64_FL_FP8FMA)
+ENTRY_FMA_FPM (vmlalbq, f16, UNSPEC_FMLALB)
+ENTRY_FMA_FPM (vmlaltq, f16, UNSPEC_FMLALT)
+ENTRY_FMA_FPM (vmlallbbq, f32, UNSPEC_FMLALLBB)
+ENTRY_FMA_FPM (vmlallbtq, f32, UNSPEC_FMLALLBT)
+ENTRY_FMA_FPM (vmlalltbq, f32, UNSPEC_FMLALLTB)
+ENTRY_FMA_FPM (vmlallttq, f32, UNSPEC_FMLALLTT)
+#undef REQUIRED_EXTENSIONS
diff --git a/gcc/config/aarch64/aarch64-simd.md b/gcc/config/aarch64/aarch64-simd.md
index ea1ef4963d2..abe607e5074 100644
--- a/gcc/config/aarch64/aarch64-simd.md
+++ b/gcc/config/aarch64/aarch64-simd.md
@@ -10125,3 +10125,32 @@ 
   "TARGET_FP8DOT4"
   "<fpm_uns_op>\t%1.<VDQSF:Vtype>, %2.<VB:Vtype>, %3.<VDQSF:Vdotlanetype>[%4]"
 )
+
+;; fpm fma instructions.
+(define_insn
+  "@aarch64_<fpm_uns_op><VQ_HSF:mode><VQ_HSF:mode><V16QI_ONLY:mode><V16QI_ONLY:mode>"
+  [(set (match_operand:VQ_HSF 0 "register_operand" "=w")
+	(unspec:VQ_HSF
+	 [(match_operand:VQ_HSF 1 "register_operand" "0")
+	  (match_operand:V16QI_ONLY 2 "register_operand" "w")
+	  (match_operand:V16QI_ONLY 3 "register_operand" "w")
+	  (reg:DI FPM_REGNUM)]
+	FPM_FMA_UNS))]
+  "TARGET_FP8FMA"
+  "<fpm_uns_op>\t%0.<VQ_HSF:Vtype>, %2.<V16QI_ONLY:Vtype>, %3.<V16QI_ONLY:Vtype>"
+)
+
+;; fpm fma instructions with lane.
+(define_insn
+  "@aarch64_<fpm_uns_op><VQ_HSF:mode><VQ_HSF:mode><V16QI_ONLY:mode><VB:mode><SI_ONLY:mode>"
+  [(set (match_operand:VQ_HSF 0 "register_operand" "=w")
+	(unspec:VQ_HSF
+	 [(match_operand:VQ_HSF 1 "register_operand" "0")
+	  (match_operand:V16QI_ONLY 2 "register_operand" "w")
+	  (match_operand:VB 3 "register_operand" "w")
+	  (match_operand:SI_ONLY 4 "const_int_operand" "n")
+	  (reg:DI FPM_REGNUM)]
+	FPM_FMA_UNS))]
+  "TARGET_FP8FMA"
+  "<fpm_uns_op>\t%0.<VQ_HSF:Vtype>, %2.<V16QI_ONLY:Vtype>, %3.b[%4]"
+)
diff --git a/gcc/config/aarch64/aarch64.h b/gcc/config/aarch64/aarch64.h
index bbe56afcb62..1d34972e7d3 100644
--- a/gcc/config/aarch64/aarch64.h
+++ b/gcc/config/aarch64/aarch64.h
@@ -489,6 +489,9 @@  constexpr auto AARCH64_FL_DEFAULT_ISA_MODE ATTRIBUTE_UNUSED
 /* fp8 dot product instructions are enabled through +fp8dot4.  */
 #define TARGET_FP8DOT4 AARCH64_HAVE_ISA (FP8DOT4)
 
+/* fp8 multiply-add instructions are enabled through +fp8fma.  */
+#define TARGET_FP8FMA AARCH64_HAVE_ISA (FP8FMA)
+
 /* Standard register usage.  */
 
 /* 31 64-bit general purpose registers R0-R30:
diff --git a/gcc/config/aarch64/iterators.md b/gcc/config/aarch64/iterators.md
index 45b9e74c231..ed8b2d93a0a 100644
--- a/gcc/config/aarch64/iterators.md
+++ b/gcc/config/aarch64/iterators.md
@@ -716,6 +716,10 @@ 
     UNSPEC_FMINNMV	; Used in aarch64-simd.md.
     UNSPEC_FMINV	; Used in aarch64-simd.md.
     UNSPEC_FADDV	; Used in aarch64-simd.md.
+    UNSPEC_FMLALLBB	; Used in aarch64-simd.md.
+    UNSPEC_FMLALLBT	; Used in aarch64-simd.md.
+    UNSPEC_FMLALLTB	; Used in aarch64-simd.md.
+    UNSPEC_FMLALLTT	; Used in aarch64-simd.md.
     UNSPEC_FNEG		; Used in aarch64-simd.md.
     UNSPEC_FSCALE	; Used in aarch64-simd.md.
     UNSPEC_ADDV		; Used in aarch64-simd.md.
@@ -4613,8 +4617,22 @@ 
 
 (define_int_iterator FPM_VDOT4_UNS [UNSPEC_VDOT4])
 
+(define_int_iterator FPM_FMA_UNS
+  [UNSPEC_FMLALB
+   UNSPEC_FMLALT
+   UNSPEC_FMLALLBB
+   UNSPEC_FMLALLBT
+   UNSPEC_FMLALLTB
+   UNSPEC_FMLALLTT])
+
 (define_int_attr fpm_uns_op
   [(UNSPEC_FSCALE "fscale")
+   (UNSPEC_FMLALB "fmlalb")
+   (UNSPEC_FMLALT "fmlalt")
+   (UNSPEC_FMLALLBB "fmlallbb")
+   (UNSPEC_FMLALLBT "fmlallbt")
+   (UNSPEC_FMLALLTB "fmlalltb")
+   (UNSPEC_FMLALLTT "fmlalltt")
    (UNSPEC_VCVT_F16 "fcvtn")
    (UNSPEC_VCVTQ_F16 "fcvtn")
    (UNSPEC_VCVT_F32 "fcvtn")
diff --git a/gcc/doc/invoke.texi b/gcc/doc/invoke.texi
index 332c664b30f..d198e258218 100644
--- a/gcc/doc/invoke.texi
+++ b/gcc/doc/invoke.texi
@@ -21809,6 +21809,8 @@  Enable the fp8 (8-bit floating point) extension.
 Enable the fp8dot2 (8-bit floating point dot product) extension.
 @item fp8dot4
 Enable the fp8dot4 (8-bit floating point dot product) extension.
+@item fp8fma
+Enable the fp8fma (8-bit floating point multiply-add) extension.
 @item faminmax
 Enable the Floating Point Absolute Maximum/Minimum extension.
 
diff --git a/gcc/testsuite/gcc.target/aarch64/simd/fma_fpm.c b/gcc/testsuite/gcc.target/aarch64/simd/fma_fpm.c
new file mode 100644
index 00000000000..ea21856fa62
--- /dev/null
+++ b/gcc/testsuite/gcc.target/aarch64/simd/fma_fpm.c
@@ -0,0 +1,221 @@ 
+/* { dg-do compile } */
+/* { dg-additional-options "-O3 -march=armv9-a+fp8fma" } */
+/* { dg-final { check-function-bodies "**" "" } } */
+
+#include "arm_neon.h"
+
+/*
+** test_vmlalbq_f16_fpm:
+**	msr	fpmr, x0
+**	fmlalb	v0.8h, v1.16b, v2.16b
+**	ret
+*/
+float16x8_t
+test_vmlalbq_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlalbq_f16_mf8_fpm (a, b, c, d);
+}
+
+/*
+** test_vmlaltq_f16_fpm:
+**	msr	fpmr, x0
+**	fmlalt	v0.8h, v1.16b, v2.16b
+**	ret
+*/
+float16x8_t
+test_vmlaltq_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlaltq_f16_mf8_fpm (a, b, c, d);
+}
+
+/*
+** test_vmlallbbq_f32_fpm:
+**	msr	fpmr, x0
+**	fmlallbb	v0.4s, v1.16b, v2.16b
+**	ret
+*/
+float32x4_t
+test_vmlallbbq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlallbbq_f32_mf8_fpm (a, b, c, d);
+}
+
+/*
+** test_vmlallbtq_f32_fpm:
+**	msr	fpmr, x0
+**	fmlallbt	v0.4s, v1.16b, v2.16b
+**	ret
+*/
+float32x4_t
+test_vmlallbtq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlallbtq_f32_mf8_fpm (a, b, c, d);
+}
+
+/*
+** test_vmlalltbq_f32_fpm:
+**	msr	fpmr, x0
+**	fmlalltb	v0.4s, v1.16b, v2.16b
+**	ret
+*/
+float32x4_t
+test_vmlalltbq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlalltbq_f32_mf8_fpm (a, b, c, d);
+}
+
+/*
+** test_vmlallttq_f32_fpm:
+**	msr	fpmr, x0
+**	fmlalltt	v0.4s, v1.16b, v2.16b
+**	ret
+*/
+float32x4_t
+test_vmlallttq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlallttq_f32_mf8_fpm (a, b, c, d);
+}
+
+/*
+** test_vmlalbq_lane_f16_fpm:
+**	msr	fpmr, x0
+**	fmlalb	v0.8h, v1.16b, v2.b\[1\]
+**	ret
+*/
+float16x8_t
+test_vmlalbq_lane_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d)
+{
+  return vmlalbq_lane_f16_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlalbq_laneq_f16_fpm:
+**	msr	fpmr, x0
+**	fmlalb	v0.8h, v1.16b, v2.b\[1\]
+**	ret
+*/
+float16x8_t
+test_vmlalbq_laneq_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlalbq_laneq_f16_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlaltq_lane_f16_fpm:
+**	msr	fpmr, x0
+**	fmlalt	v0.8h, v1.16b, v2.b\[1\]
+**	ret
+*/
+float16x8_t
+test_vmlaltq_lane_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d)
+{
+  return vmlaltq_lane_f16_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlaltq_laneq_f16_fpm:
+**	msr	fpmr, x0
+**	fmlalt	v0.8h, v1.16b, v2.b\[1\]
+**	ret
+*/
+float16x8_t
+test_vmlaltq_laneq_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlaltq_laneq_f16_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlallbbq_lane_f32_fpm:
+**	msr	fpmr, x0
+**	fmlallbb	v0.4s, v1.16b, v2.b\[1\]
+**	ret
+*/
+float32x4_t
+test_vmlallbbq_lane_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d)
+{
+  return vmlallbbq_lane_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlallbbq_laneq_f32_fpm:
+**	msr	fpmr, x0
+**	fmlallbb	v0.4s, v1.16b, v2.b\[1\]
+**	ret
+*/
+float32x4_t
+test_vmlallbbq_laneq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlallbbq_laneq_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlallbtq_lane_f32_fpm:
+**	msr	fpmr, x0
+**	fmlallbt	v0.4s, v1.16b, v2.b\[1\]
+**	ret
+*/
+float32x4_t
+test_vmlallbtq_lane_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d)
+{
+  return vmlallbtq_lane_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlallbtq_laneq_f32_fpm:
+**	msr	fpmr, x0
+**	fmlallbt	v0.4s, v1.16b, v2.b\[1\]
+**	ret
+*/
+float32x4_t
+test_vmlallbtq_laneq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlallbtq_laneq_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlalltbq_lane_f32_fpm:
+**	msr	fpmr, x0
+**	fmlalltb	v0.4s, v1.16b, v2.b\[1\]
+**	ret
+*/
+float32x4_t
+test_vmlalltbq_lane_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d)
+{
+  return vmlalltbq_lane_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlalltbq_laneq_f32_fpm:
+**	msr	fpmr, x0
+**	fmlalltb	v0.4s, v1.16b, v2.b\[1\]
+**	ret
+*/
+float32x4_t
+test_vmlalltbq_laneq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlalltbq_laneq_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlallttq_lane_f32_fpm:
+**	msr	fpmr, x0
+**	fmlalltt	v0.4s, v1.16b, v2.b\[1\]
+**	ret
+*/
+float32x4_t
+test_vmlallttq_lane_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d)
+{
+  return vmlallttq_lane_f32_mf8_fpm (a, b, c, 1, d);
+}
+
+/*
+** test_vmlallttq_laneq_f32_fpm:
+**	msr	fpmr, x0
+**	fmlalltt	v0.4s, v1.16b, v2.b\[1\]
+**	ret
+*/
+float32x4_t
+test_vmlallttq_laneq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d)
+{
+  return vmlallttq_laneq_f32_mf8_fpm (a, b, c, 1, d);
+}