diff mbox series

target/arm/tcg: Fix overflow in matrix-multiply accumulate

Message ID 20240811054341.745674-1-joe@pf.is.s.u-tokyo.ac.jp
State New
Headers show
Series target/arm/tcg: Fix overflow in matrix-multiply accumulate | expand

Commit Message

Joe Hattori Aug. 11, 2024, 5:43 a.m. UTC
Arm's intrinsic matrix multiply accumulate instructions take two 8-bit
vector and add up a 32-bit vector. Current emulation causes overflow
when large 8-bit integers are used. This commit fixes the issue by
casting the 8-bit integers to 32-bit integers before multiplication.

Fixes: 2323c5ffd4b5 ("target/arm: Implement integer matrix multiply accumulate")
Signed-off-by: Joe Hattori <joe@pf.is.s.u-tokyo.ac.jp>
---
 target/arm/tcg/vec_helper.c | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

Comments

Richard Henderson Aug. 11, 2024, 9:42 p.m. UTC | #1
On 8/11/24 15:43, Joe Hattori wrote:
> Arm's intrinsic matrix multiply accumulate instructions take two 8-bit
> vector and add up a 32-bit vector. Current emulation causes overflow
> when large 8-bit integers are used. This commit fixes the issue by
> casting the 8-bit integers to 32-bit integers before multiplication.

"Large 8-bit integers"?

0xff * 0xff = 0xfe01.

This in no way overflows "int" on any supported host, which is the type we get via normal 
C arithmetic promotion rules.

So what is this supposed to fix?


r~

> 
> Fixes: 2323c5ffd4b5 ("target/arm: Implement integer matrix multiply accumulate")
> Signed-off-by: Joe Hattori <joe@pf.is.s.u-tokyo.ac.jp>
> ---
>   target/arm/tcg/vec_helper.c | 6 +++---
>   1 file changed, 3 insertions(+), 3 deletions(-)
> 
> diff --git a/target/arm/tcg/vec_helper.c b/target/arm/tcg/vec_helper.c
> index 98604d170fd3..e9c33520232a 100644
> --- a/target/arm/tcg/vec_helper.c
> +++ b/target/arm/tcg/vec_helper.c
> @@ -2718,7 +2718,7 @@ static uint32_t do_smmla_b(uint32_t sum, void *vn, void *vm)
>       int8_t *n = vn, *m = vm;
>   
>       for (intptr_t k = 0; k < 8; ++k) {
> -        sum += n[H1(k)] * m[H1(k)];
> +        sum += (uint32_t)n[H1(k)] * (uint32_t)m[H1(k)];
>       }
>       return sum;
>   }
> @@ -2728,7 +2728,7 @@ static uint32_t do_ummla_b(uint32_t sum, void *vn, void *vm)
>       uint8_t *n = vn, *m = vm;
>   
>       for (intptr_t k = 0; k < 8; ++k) {
> -        sum += n[H1(k)] * m[H1(k)];
> +        sum += (uint32_t)n[H1(k)] * (uint32_t)m[H1(k)];
>       }
>       return sum;
>   }
> @@ -2739,7 +2739,7 @@ static uint32_t do_usmmla_b(uint32_t sum, void *vn, void *vm)
>       int8_t *m = vm;
>   
>       for (intptr_t k = 0; k < 8; ++k) {
> -        sum += n[H1(k)] * m[H1(k)];
> +        sum += (uint32_t)n[H1(k)] * (uint32_t)m[H1(k)];
>       }
>       return sum;
>   }
diff mbox series

Patch

diff --git a/target/arm/tcg/vec_helper.c b/target/arm/tcg/vec_helper.c
index 98604d170fd3..e9c33520232a 100644
--- a/target/arm/tcg/vec_helper.c
+++ b/target/arm/tcg/vec_helper.c
@@ -2718,7 +2718,7 @@  static uint32_t do_smmla_b(uint32_t sum, void *vn, void *vm)
     int8_t *n = vn, *m = vm;
 
     for (intptr_t k = 0; k < 8; ++k) {
-        sum += n[H1(k)] * m[H1(k)];
+        sum += (uint32_t)n[H1(k)] * (uint32_t)m[H1(k)];
     }
     return sum;
 }
@@ -2728,7 +2728,7 @@  static uint32_t do_ummla_b(uint32_t sum, void *vn, void *vm)
     uint8_t *n = vn, *m = vm;
 
     for (intptr_t k = 0; k < 8; ++k) {
-        sum += n[H1(k)] * m[H1(k)];
+        sum += (uint32_t)n[H1(k)] * (uint32_t)m[H1(k)];
     }
     return sum;
 }
@@ -2739,7 +2739,7 @@  static uint32_t do_usmmla_b(uint32_t sum, void *vn, void *vm)
     int8_t *m = vm;
 
     for (intptr_t k = 0; k < 8; ++k) {
-        sum += n[H1(k)] * m[H1(k)];
+        sum += (uint32_t)n[H1(k)] * (uint32_t)m[H1(k)];
     }
     return sum;
 }