diff mbox series

[7/8] target/arm: Implement FPCR.EBF=1 semantics for bfdotadd()

Message ID 20240730160306.2959745-8-peter.maydell@linaro.org
State New
Headers show
Series target/arm: Implement FEAT_EBF16 | expand

Commit Message

Peter Maydell July 30, 2024, 4:03 p.m. UTC
Implement the FPCR.EBF=1 semantics for bfdotadd() operations:
 * is_ebf() sets up fpst and fpst_odd
 * bfdotadd_ebf() implements the fused paired-multiply-and-add
   operation that we need

The paired-multiply-and-add is similar to f16_dotadd() and
we use the same trick here as in that function, but the inputs
here are bfloat16 rather than float16.

Signed-off-by: Peter Maydell <peter.maydell@linaro.org>
---
 target/arm/tcg/vec_helper.c | 57 +++++++++++++++++++++++++++++++++++--
 1 file changed, 54 insertions(+), 3 deletions(-)

Comments

Richard Henderson July 31, 2024, 1:50 a.m. UTC | #1
On 7/31/24 02:03, Peter Maydell wrote:
> Implement the FPCR.EBF=1 semantics for bfdotadd() operations:
>   * is_ebf() sets up fpst and fpst_odd
>   * bfdotadd_ebf() implements the fused paired-multiply-and-add
>     operation that we need
> 
> The paired-multiply-and-add is similar to f16_dotadd() and
> we use the same trick here as in that function, but the inputs
> here are bfloat16 rather than float16.
> 
> Signed-off-by: Peter Maydell<peter.maydell@linaro.org>
> ---
>   target/arm/tcg/vec_helper.c | 57 +++++++++++++++++++++++++++++++++++--
>   1 file changed, 54 insertions(+), 3 deletions(-)

Reviewed-by: Richard Henderson <richard.henderson@linaro.org>

r~
diff mbox series

Patch

diff --git a/target/arm/tcg/vec_helper.c b/target/arm/tcg/vec_helper.c
index baf04a0561b..64076c1c595 100644
--- a/target/arm/tcg/vec_helper.c
+++ b/target/arm/tcg/vec_helper.c
@@ -2792,7 +2792,20 @@  DO_MMLA_B(gvec_usmmla_b, do_usmmla_b)
 
 bool is_ebf(CPUARMState *env, float_status *statusp, float_status *oddstatusp)
 {
-    /* FPCR is ignored for BFDOT and BFMMLA. */
+    /*
+     * For BFDOT, BFMMLA, etc, the behaviour depends on FPCR.EBF.
+     * For EBF = 0, we ignore the FPCR bits which determine rounding
+     * mode and denormal-flushing, and we do unfused multiplies and
+     * additions with intermediate rounding of all products and sums.
+     * For EBF = 1, we honour FPCR rounding mode and denormal-flushing bits,
+     * and we perform a fused two-way sum-of-products without intermediate
+     * rounding of the products.
+     * In either case, we don't set fp exception flags.
+     *
+     * EBF is AArch64 only, so even if it's set in the FPCR it has
+     * no effect on AArch32 instructions.
+     */
+    bool ebf = is_a64(env) && env->vfp.fpcr & FPCR_EBF;
     float_status bf_status = {
         .tininess_before_rounding = float_tininess_before_rounding,
         .float_rounding_mode = float_round_to_odd_inf,
@@ -2801,8 +2814,19 @@  bool is_ebf(CPUARMState *env, float_status *statusp, float_status *oddstatusp)
         .default_nan_mode = true,
     };
 
+    if (ebf) {
+        float_status *fpst = &env->vfp.fp_status;
+        set_flush_to_zero(get_flush_to_zero(fpst), &bf_status);
+        set_flush_inputs_to_zero(get_flush_inputs_to_zero(fpst), &bf_status);
+        set_float_rounding_mode(get_float_rounding_mode(fpst), &bf_status);
+
+        /* EBF=1 needs to do a step with round-to-odd semantics */
+        *oddstatusp = bf_status;
+        set_float_rounding_mode(float_round_to_odd, oddstatusp);
+    }
+
     *statusp = bf_status;
-    return false;
+    return ebf;
 }
 
 float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2, float_status *fpst)
@@ -2824,7 +2848,34 @@  float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2, float_status *fpst)
 float32 bfdotadd_ebf(float32 sum, uint32_t e1, uint32_t e2,
                      float_status *fpst, float_status *fpst_odd)
 {
-    g_assert_not_reached();
+    /*
+     * Compare f16_dotadd() in sme_helper.c, but here we have
+     * bfloat16 inputs. In particular that means that we do not
+     * want the FPCR.FZ16 flush semantics, so we use the normal
+     * float_status for the input handling here.
+     */
+    float64 e1r = float32_to_float64(e1 << 16, fpst);
+    float64 e1c = float32_to_float64(e1 & 0xffff0000u, fpst);
+    float64 e2r = float32_to_float64(e2 << 16, fpst);
+    float64 e2c = float32_to_float64(e2 & 0xffff0000u, fpst);
+    float64 t64;
+    float32 t32;
+
+    /*
+     * The ARM pseudocode function FPDot performs both multiplies
+     * and the add with a single rounding operation.  Emulate this
+     * by performing the first multiply in round-to-odd, then doing
+     * the second multiply as fused multiply-add, and rounding to
+     * float32 all in one step.
+     */
+    t64 = float64_mul(e1r, e2r, fpst_odd);
+    t64 = float64r32_muladd(e1c, e2c, t64, 0, fpst);
+
+    /* This conversion is exact, because we've already rounded. */
+    t32 = float64_to_float32(t64, fpst);
+
+    /* The final accumulation step is not fused. */
+    return float32_add(sum, t32, fpst);
 }
 
 void HELPER(gvec_bfdot)(void *vd, void *vn, void *vm, void *va,