diff mbox series

i386: Utilize VCOMSBF16 for BF16 Comparisons with AVX10.2

Message ID 20241017021950.756964-1-admin@levyhsu.com
State New
Headers show
Series i386: Utilize VCOMSBF16 for BF16 Comparisons with AVX10.2 | expand

Commit Message

Levy Hsu Oct. 17, 2024, 2:19 a.m. UTC
Bootstrapped and regtested on x86_64-pc-linux-gnu{-m64}.
Ok for trunk?

This patch enables the use of the VCOMSBF16 instruction from AVX10.2 for
efficient BF16 comparisons.

gcc/ChangeLog:

	* config/i386/i386-expand.cc (ix86_expand_branch): Handle BFmode
	when TARGET_AVX10_2_256 is enabled.
	(ix86_prepare_fp_compare_args):
	Renamed SSE_FLOAT_MODE_SSEMATH_OR_HF_P to SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P.
	(ix86_expand_fp_compare): For BFmode with IX86_FPCMP_COMI, use cmpibf.
	(ix86_expand_fp_movcc):
	Renamed SSE_FLOAT_MODE_SSEMATH_OR_HF_P to SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P.
	* config/i386/i386.cc (ix86_multiplication_cost): Ditto.
	(ix86_division_cost): Ditto.
	(ix86_rtx_costs): Ditto.
	(ix86_vector_costs::add_stmt_cost): Ditto.
	* config/i386/i386.h (SSE_FLOAT_MODE_SSEMATH_OR_HF_P):  Ditto.
	(SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P): Add BFmode.
	* config/i386/i386.md (*cmpibf): New insn for cmpibf.

gcc/testsuite/ChangeLog:

	* gcc.target/i386/avx10_2-comibf-1.c: New test.
	* gcc.target/i386/avx10_2-comibf-2.c: New test.
---
 gcc/config/i386/i386-expand.cc                |  22 ++--
 gcc/config/i386/i386.cc                       |  22 ++--
 gcc/config/i386/i386.h                        |   7 +-
 gcc/config/i386/i386.md                       |  27 +++-
 .../gcc.target/i386/avx10_2-comibf-1.c        |  40 ++++++
 .../gcc.target/i386/avx10_2-comibf-2.c        | 115 ++++++++++++++++++
 6 files changed, 208 insertions(+), 25 deletions(-)
 create mode 100644 gcc/testsuite/gcc.target/i386/avx10_2-comibf-1.c
 create mode 100644 gcc/testsuite/gcc.target/i386/avx10_2-comibf-2.c
diff mbox series

Patch

diff --git a/gcc/config/i386/i386-expand.cc b/gcc/config/i386/i386-expand.cc
index 63f5e348d64..ce413fa0eba 100644
--- a/gcc/config/i386/i386-expand.cc
+++ b/gcc/config/i386/i386-expand.cc
@@ -2530,6 +2530,10 @@  ix86_expand_branch (enum rtx_code code, rtx op0, rtx op1, rtx label)
       emit_jump_insn (gen_rtx_SET (pc_rtx, tmp));
       return;
 
+    case E_BFmode:
+      gcc_assert (TARGET_AVX10_2_256 && !flag_trapping_math);
+      goto simple;
+
     case E_DImode:
       if (TARGET_64BIT)
 	goto simple;
@@ -2796,9 +2800,9 @@  ix86_prepare_fp_compare_args (enum rtx_code code, rtx *pop0, rtx *pop1)
   bool unordered_compare = ix86_unordered_fp_compare (code);
   rtx op0 = *pop0, op1 = *pop1;
   machine_mode op_mode = GET_MODE (op0);
-  bool is_sse = SSE_FLOAT_MODE_SSEMATH_OR_HF_P (op_mode);
+  bool is_sse = SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (op_mode);
 
-  if (op_mode == BFmode)
+  if (op_mode == BFmode && (!TARGET_AVX10_2_256 || flag_trapping_math))
     {
       rtx op = gen_lowpart (HImode, op0);
       if (CONST_INT_P (op))
@@ -2917,10 +2921,14 @@  ix86_expand_fp_compare (enum rtx_code code, rtx op0, rtx op1)
     {
     case IX86_FPCMP_COMI:
       tmp = gen_rtx_COMPARE (CCFPmode, op0, op1);
-      if (TARGET_AVX10_2_256 && (code == EQ || code == NE))
-	tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_OPTCOMX);
-      if (unordered_compare)
-	tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_NOTRAP);
+      /* We only have vcomsbf16, No vcomubf16 nor vcomxbf16 */
+      if (GET_MODE (op0) != E_BFmode)
+        {
+	  if (TARGET_AVX10_2_256 && (code == EQ || code == NE))
+	    tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_OPTCOMX);
+	  if (unordered_compare)
+	    tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_NOTRAP);
+	}
       cmp_mode = CCFPmode;
       emit_insn (gen_rtx_SET (gen_rtx_REG (CCFPmode, FLAGS_REG), tmp));
       break;
@@ -4635,7 +4643,7 @@  ix86_expand_fp_movcc (rtx operands[])
       && !ix86_fp_comparison_operator (operands[1], VOIDmode))
     return false;
 
-  if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+  if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
     {
       machine_mode cmode;
 
diff --git a/gcc/config/i386/i386.cc b/gcc/config/i386/i386.cc
index a1f0ae7a7e1..c7132252e48 100644
--- a/gcc/config/i386/i386.cc
+++ b/gcc/config/i386/i386.cc
@@ -21324,7 +21324,7 @@  ix86_multiplication_cost (const struct processor_costs *cost,
   if (VECTOR_MODE_P (mode))
     inner_mode = GET_MODE_INNER (mode);
 
-  if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+  if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
     return inner_mode == DFmode ? cost->mulsd : cost->mulss;
   else if (X87_FLOAT_MODE_P (mode))
     return cost->fmul;
@@ -21449,7 +21449,7 @@  ix86_division_cost (const struct processor_costs *cost,
   if (VECTOR_MODE_P (mode))
     inner_mode = GET_MODE_INNER (mode);
 
-  if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+  if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
     return inner_mode == DFmode ? cost->divsd : cost->divss;
   else if (X87_FLOAT_MODE_P (mode))
     return cost->fdiv;
@@ -21991,7 +21991,7 @@  ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
 	  return true;
 	}
 
-      if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+      if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
 	*total = cost->addss;
       else if (X87_FLOAT_MODE_P (mode))
 	*total = cost->fadd;
@@ -22198,7 +22198,7 @@  ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
       return false;
 
     case NEG:
-      if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+      if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
 	*total = cost->sse_op;
       else if (X87_FLOAT_MODE_P (mode))
 	*total = cost->fchs;
@@ -22306,14 +22306,14 @@  ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
       return false;
 
     case FLOAT_EXTEND:
-      if (!SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+      if (!SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
 	*total = 0;
       else
         *total = ix86_vec_cost (mode, cost->addss);
       return false;
 
     case FLOAT_TRUNCATE:
-      if (!SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+      if (!SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
 	*total = cost->fadd;
       else
         *total = ix86_vec_cost (mode, cost->addss);
@@ -22323,7 +22323,7 @@  ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
       /* SSE requires memory load for the constant operand. It may make
 	 sense to account for this.  Of course the constant operand may or
 	 may not be reused. */
-      if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+      if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
 	*total = cost->sse_op;
       else if (X87_FLOAT_MODE_P (mode))
 	*total = cost->fabs;
@@ -22334,7 +22334,7 @@  ix86_rtx_costs (rtx x, machine_mode mode, int outer_code_i, int opno,
       return false;
 
     case SQRT:
-      if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+      if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
 	*total = mode == SFmode ? cost->sqrtss : cost->sqrtsd;
       else if (X87_FLOAT_MODE_P (mode))
 	*total = cost->fsqrt;
@@ -25083,7 +25083,7 @@  ix86_vector_costs::add_stmt_cost (int count, vect_cost_for_stmt kind,
 	case MINUS_EXPR:
 	  if (kind == scalar_stmt)
 	    {
-	      if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+	      if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
 		stmt_cost = ix86_cost->addss;
 	      else if (X87_FLOAT_MODE_P (mode))
 		stmt_cost = ix86_cost->fadd;
@@ -25109,7 +25109,7 @@  ix86_vector_costs::add_stmt_cost (int count, vect_cost_for_stmt kind,
 	  break;
 
 	case NEGATE_EXPR:
-	  if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+	  if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
 	    stmt_cost = ix86_cost->sse_op;
 	  else if (X87_FLOAT_MODE_P (mode))
 	    stmt_cost = ix86_cost->fchs;
@@ -25165,7 +25165,7 @@  ix86_vector_costs::add_stmt_cost (int count, vect_cost_for_stmt kind,
 	case BIT_XOR_EXPR:
 	case BIT_AND_EXPR:
 	case BIT_NOT_EXPR:
-	  if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode))
+	  if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode))
 	    stmt_cost = ix86_cost->sse_op;
 	  else if (VECTOR_MODE_P (mode))
 	    stmt_cost = ix86_vec_cost (mode, ix86_cost->sse_op);
diff --git a/gcc/config/i386/i386.h b/gcc/config/i386/i386.h
index f5204aa1ed2..d56a23e2b7b 100644
--- a/gcc/config/i386/i386.h
+++ b/gcc/config/i386/i386.h
@@ -1158,9 +1158,10 @@  extern const char *host_detect_local_cpu (int argc, const char **argv);
 #define SSE_FLOAT_MODE_P(MODE) \
   ((TARGET_SSE && (MODE) == SFmode) || (TARGET_SSE2 && (MODE) == DFmode))
 
-#define SSE_FLOAT_MODE_SSEMATH_OR_HF_P(MODE)				\
-  ((SSE_FLOAT_MODE_P (MODE) && TARGET_SSE_MATH)				\
-   || (TARGET_AVX512FP16 && (MODE) == HFmode))
+#define SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P(MODE)                          \
+  ((SSE_FLOAT_MODE_P (MODE) && TARGET_SSE_MATH)                         \
+   || (TARGET_AVX512FP16 && (MODE) == HFmode)                           \
+   || (TARGET_AVX10_2_256 && (MODE) == BFmode))
 
 #define FMA4_VEC_FLOAT_MODE_P(MODE) \
   (TARGET_FMA4 && ((MODE) == V4SFmode || (MODE) == V2DFmode \
diff --git a/gcc/config/i386/i386.md b/gcc/config/i386/i386.md
index e4d1c56ea54..dce21e9962e 100644
--- a/gcc/config/i386/i386.md
+++ b/gcc/config/i386/i386.md
@@ -1814,14 +1814,22 @@ 
 	      (pc)))]
   "TARGET_80387 || (SSE_FLOAT_MODE_P (SFmode) && TARGET_SSE_MATH)"
 {
-  rtx op1 = ix86_expand_fast_convert_bf_to_sf (operands[1]);
-  rtx op2 = ix86_expand_fast_convert_bf_to_sf (operands[2]);
-  do_compare_rtx_and_jump (op1, op2, GET_CODE (operands[0]), 0,
+  if (TARGET_AVX10_2_256 && !flag_trapping_math)
+    {
+      ix86_expand_branch (GET_CODE (operands[0]),
+		      operands[1], operands[2], operands[3]);
+    }
+  else
+    {
+      rtx op1 = ix86_expand_fast_convert_bf_to_sf (operands[1]);
+      rtx op2 = ix86_expand_fast_convert_bf_to_sf (operands[2]);
+      do_compare_rtx_and_jump (op1, op2, GET_CODE (operands[0]), 0,
 			   SFmode, NULL_RTX, NULL,
 			   as_a <rtx_code_label *> (operands[3]),
 			   /* Unfortunately this isn't propagated.  */
 			   profile_probability::even ());
-  DONE;
+    }
+    DONE;
 })
 
 (define_expand "cstorehf4"
@@ -2096,6 +2104,17 @@ 
    (set_attr "prefix" "evex")
    (set_attr "mode" "HF")])
 
+(define_insn "*cmpibf"
+  [(set (reg:CCFP FLAGS_REG)
+	(compare:CCFP
+	  (match_operand:BF 0 "register_operand" "v")
+	  (match_operand:BF 1 "nonimmediate_operand" "vm")))]
+  "TARGET_AVX10_2_256"
+  "vcomsbf16\t{%1, %0|%0, %1}"
+  [(set_attr "type" "ssecomi")
+   (set_attr "prefix" "evex")
+   (set_attr "mode" "BF")])
+
 ;; Set carry flag.
 (define_insn "x86_stc"
   [(set (reg:CCC FLAGS_REG) (unspec:CCC [(const_int 0)] UNSPEC_STC))]
diff --git a/gcc/testsuite/gcc.target/i386/avx10_2-comibf-1.c b/gcc/testsuite/gcc.target/i386/avx10_2-comibf-1.c
new file mode 100644
index 00000000000..85b773b89f2
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/avx10_2-comibf-1.c
@@ -0,0 +1,40 @@ 
+/* { dg-do compile } */
+/* { dg-options "-march=x86-64-v3 -mavx10.2 -O2 -fno-trapping-math" } */
+/* { dg-final { scan-assembler-times "vcomsbf16\[ \\t\]+\[^{}\n\]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 6 } } */
+/* { dg-final { scan-assembler-times {j[a-z]+\s} 6 } } */
+
+__bf16
+foo_eq (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
+{
+  return a == b ? c + d : c - d;
+}
+
+__bf16
+foo_ne (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
+{
+  return a != b ? c + d : c - d;
+}
+
+__bf16
+foo_lt (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
+{
+  return a < b ? c + d : c - d;
+}
+
+__bf16
+foo_le (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
+{
+  return a <= b ? c + d : c - d;
+}
+
+__bf16
+foo_gt (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
+{
+  return a > b ? c + d : c - d;
+}
+
+__bf16
+foo_ge (__bf16 a, __bf16 b, __bf16 c, __bf16 d)
+{
+  return a >= b ? c + d : c - d;
+}
diff --git a/gcc/testsuite/gcc.target/i386/avx10_2-comibf-2.c b/gcc/testsuite/gcc.target/i386/avx10_2-comibf-2.c
new file mode 100644
index 00000000000..f53ce6b18a8
--- /dev/null
+++ b/gcc/testsuite/gcc.target/i386/avx10_2-comibf-2.c
@@ -0,0 +1,115 @@ 
+ /* { dg-do run } */
+/* { dg-options "-march=x86-64-v3 -mavx10.2 -O2 -fno-trapping-math" } */
+
+#include <stdlib.h>
+#include <stdint.h>
+#include <string.h>
+
+/* Fast shift conversion here for convenience */
+static __bf16
+float_to_bf16 (float f)
+{
+  uint32_t float_bits;
+  uint16_t bf16_bits;
+
+  memcpy (&float_bits, &f, sizeof (float_bits));
+  bf16_bits = (uint16_t) (float_bits >> 16);
+
+  __bf16 bf;
+  memcpy (&bf, &bf16_bits, sizeof (bf));
+  return bf;
+}
+
+static float
+bf16_to_float (__bf16 bf)
+{
+  uint32_t float_bits;
+  uint16_t bf16_bits;
+
+  memcpy (&bf16_bits, &bf, sizeof (bf16_bits));
+  float_bits = ((uint32_t) bf16_bits) << 16;
+
+  float f;
+  memcpy (&f, &float_bits, sizeof (f));
+  return f;
+}
+
+static void
+test_eq (__bf16 a, __bf16 b)
+{
+  int result = (a == b);
+  int expected = (bf16_to_float (a) == bf16_to_float (b));
+  if (result != expected)
+    abort ();
+}
+
+static void
+test_ne (__bf16 a, __bf16 b)
+{
+  int result = (a != b);
+  int expected = (bf16_to_float (a) != bf16_to_float (b));
+  if (result != expected)
+    abort ();
+}
+
+static void
+test_lt (__bf16 a, __bf16 b)
+{
+  int result = (a < b);
+  int expected = (bf16_to_float (a) < bf16_to_float (b));
+  if (result != expected)
+    abort ();
+}
+
+static void
+test_le (__bf16 a, __bf16 b)
+{
+  int result = (a <= b);
+  int expected = (bf16_to_float (a) <= bf16_to_float (b));
+  if (result != expected)
+    abort ();
+}
+
+static void
+test_gt (__bf16 a, __bf16 b)
+{
+  int result = (a > b);
+  int expected = (bf16_to_float (a) > bf16_to_float (b));
+  if (result != expected)
+    abort ();
+}
+
+static void
+test_ge (__bf16 a, __bf16 b)
+{
+  int result = (a >= b);
+  int expected = (bf16_to_float (a) >= bf16_to_float (b));
+  if (result != expected)
+    abort ();
+}
+
+int
+main (void)
+{
+  float test_values[] = {
+    -10.0f, -1.0f, -0.5f, 0.0f, 0.5f, 1.0f, 10.0f, 100.0f, -100.0f
+  };
+
+  size_t num_values = sizeof (test_values) / sizeof (test_values[0]);
+
+  for (size_t i = 0; i < num_values; i++)
+      for (size_t j = 0; j < num_values; j++)
+        {
+          __bf16 a = float_to_bf16 (test_values[i]);
+          __bf16 b = float_to_bf16 (test_values[j]);
+
+          test_eq (a, b);
+          test_ne (a, b);
+          test_lt (a, b);
+          test_le (a, b);
+          test_gt (a, b);
+          test_ge (a, b);
+        }
+
+  return 0;
+}