diff mbox series

x86: Fix unchecked AVX512-VBMI2 usage in strrchr-evex-base.S

Message ID 20231112205120.3672636-1-goldstein.w.n@gmail.com
State New
Headers show
Series x86: Fix unchecked AVX512-VBMI2 usage in strrchr-evex-base.S | expand

Commit Message

Noah Goldstein Nov. 12, 2023, 8:51 p.m. UTC
strrchr-evex-base used `vpcompress{b|d}` in the page cross logic but
was missing the CPU_FEATURE checks for VBMI2 in the
ifunc/ifunc-impl-list.

The fix is either to add those checks or change the logic to not use
`vpcompress{b|d}`. Choosing the latter here so that the strrchr-evex
implementation is usable on SKX.

New implementation is a bit slower, but this is in a cold path so its
probably okay.
---
 sysdeps/x86_64/multiarch/strrchr-evex-base.S | 75 +++++++++++++-------
 1 file changed, 51 insertions(+), 24 deletions(-)

Comments

Sunil Pandey Nov. 13, 2023, 2:07 p.m. UTC | #1
On Sun, Nov 12, 2023 at 12:51 PM Noah Goldstein <goldstein.w.n@gmail.com>
wrote:

> strrchr-evex-base used `vpcompress{b|d}` in the page cross logic but
> was missing the CPU_FEATURE checks for VBMI2 in the
> ifunc/ifunc-impl-list.
>
> The fix is either to add those checks or change the logic to not use
> `vpcompress{b|d}`. Choosing the latter here so that the strrchr-evex
> implementation is usable on SKX.
>
> New implementation is a bit slower, but this is in a cold path so its
> probably okay.
> ---
>  sysdeps/x86_64/multiarch/strrchr-evex-base.S | 75 +++++++++++++-------
>  1 file changed, 51 insertions(+), 24 deletions(-)
>
> diff --git a/sysdeps/x86_64/multiarch/strrchr-evex-base.S
> b/sysdeps/x86_64/multiarch/strrchr-evex-base.S
> index cd6a0a870a..da5bf1237a 100644
> --- a/sysdeps/x86_64/multiarch/strrchr-evex-base.S
> +++ b/sysdeps/x86_64/multiarch/strrchr-evex-base.S
> @@ -35,18 +35,20 @@
>  #  define CHAR_SIZE    4
>  #  define VPCMP                vpcmpd
>  #  define VPMIN                vpminud
> -#  define VPCOMPRESS   vpcompressd
>  #  define VPTESTN      vptestnmd
>  #  define VPTEST       vptestmd
>  #  define VPBROADCAST  vpbroadcastd
>  #  define VPCMPEQ      vpcmpeqd
>
>  # else
> -#  define SHIFT_REG    VRDI
> +#  if VEC_SIZE == 64
> +#   define SHIFT_REG   VRCX
> +#  else
> +#   define SHIFT_REG   VRDI
> +#  endif
>  #  define CHAR_SIZE    1
>  #  define VPCMP                vpcmpb
>  #  define VPMIN                vpminub
> -#  define VPCOMPRESS   vpcompressb
>  #  define VPTESTN      vptestnmb
>  #  define VPTEST       vptestmb
>  #  define VPBROADCAST  vpbroadcastb
> @@ -56,6 +58,12 @@
>  #  define KORTEST_M    KORTEST
>  # endif
>
> +# if VEC_SIZE == 32 || (defined USE_AS_WCSRCHR)
> +#  define SHIFT_R(cnt, val)    shrx cnt, val, val
> +# else
> +#  define SHIFT_R(cnt, val)    shr %cl, val
> +# endif
> +
>  # define VMATCH                VMM(0)
>  # define CHAR_PER_VEC  (VEC_SIZE / CHAR_SIZE)
>  # define PAGE_SIZE     4096
> @@ -71,7 +79,7 @@ ENTRY_P2ALIGN(STRRCHR, 6)
>         andl    $(PAGE_SIZE - 1), %eax
>         cmpl    $(PAGE_SIZE - VEC_SIZE), %eax
>         jg      L(cross_page_boundary)
> -
> +L(page_cross_continue):
>         VMOVU   (%rdi), %VMM(1)
>         /* k0 has a 1 for each zero CHAR in YMM1.  */
>         VPTESTN %VMM(1), %VMM(1), %k0
> @@ -79,7 +87,7 @@ ENTRY_P2ALIGN(STRRCHR, 6)
>         test    %VGPR(rsi), %VGPR(rsi)
>         jz      L(aligned_more)
>         /* fallthrough: zero CHAR in first VEC.  */
> -L(page_cross_return):
> +
>         /* K1 has a 1 for each search CHAR match in VEC(1).  */
>         VPCMPEQ %VMATCH, %VMM(1), %k1
>         KMOV    %k1, %VGPR(rax)
> @@ -167,7 +175,6 @@ L(first_vec_x1_return):
>
>         .p2align 4,, 12
>  L(aligned_more):
> -L(page_cross_continue):
>         /* Need to keep original pointer incase VEC(1) has last match.  */
>         movq    %rdi, %rsi
>         andq    $-VEC_SIZE, %rdi
> @@ -340,34 +347,54 @@ L(return_new_match_ret):
>         leaq    (VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax
>         ret
>
> -       .p2align 4,, 4
>  L(cross_page_boundary):
> +       /* eax contains all the page offset bits of src (rdi). `xor rdi,
> +          rax` sets pointer will all page offset bits cleared so
> +          offset of (PAGE_SIZE - VEC_SIZE) will get last aligned VEC
> +          before page cross (guaranteed to be safe to read). Doing this
> +          as opposed to `movq %rdi, %rax; andq $-VEC_SIZE, %rax` saves
> +          a bit of code size.  */
>         xorq    %rdi, %rax
> -       mov     $-1, %VRDX
> -       VMOVU   (PAGE_SIZE - VEC_SIZE)(%rax), %VMM(6)
> -       VPTESTN %VMM(6), %VMM(6), %k0
> +       VMOVU   (PAGE_SIZE - VEC_SIZE)(%rax), %VMM(1)
> +       VPTESTN %VMM(1), %VMM(1), %k0
>         KMOV    %k0, %VRSI
>
> -# ifdef USE_AS_WCSRCHR
> +       /* Shift out zero CHAR matches that are before the beginning of
> +          src (rdi).  */
> +# if VEC_SIZE == 64 || (defined USE_AS_WCSRCHR)
>         movl    %edi, %ecx
> -       and     $(VEC_SIZE - 1), %ecx
> +# endif
> +# ifdef USE_AS_WCSRCHR
> +       andl    $(VEC_SIZE - 1), %ecx
>         shrl    $2, %ecx
>  # endif
> -       shlx    %SHIFT_REG, %VRDX, %VRDX
> +       SHIFT_R (%SHIFT_REG, %VRSI)
> +# if VEC_SIZE == 32 || (defined USE_AS_WCSRCHR)
> +       /* For strrchr-evex512 we use SHIFT_R as shr which will set zero
> +          flag.  */
> +       test    %VRSI, %VRSI
> +# endif
> +       jz      L(page_cross_continue)
>
> +       /* Found zero CHAR so need to test for search CHAR.  */
> +       VPCMPEQ %VMATCH, %VMM(1), %k1
> +       KMOV    %k1, %VRAX
> +       /* Shift out search CHAR matches that are before the beginning of
> +          src (rdi).  */
> +       SHIFT_R (%SHIFT_REG, %VRAX)
> +       /* Check if any search CHAR match in range.  */
> +       blsmsk  %VRSI, %VRSI
> +       and     %VRSI, %VRAX
> +       jz      L(ret2)
> +       bsr     %VRAX, %VRAX
>  # ifdef USE_AS_WCSRCHR
> -       kmovw   %edx, %k1
> +       leaq    (%rdi, %rax, CHAR_SIZE), %rax
>  # else
> -       KMOV    %VRDX, %k1
> +       addq    %rdi, %rax
>  # endif
> -
> -       VPCOMPRESS %VMM(6), %VMM(1){%k1}{z}
> -       /* We could technically just jmp back after the vpcompress but
> -          it doesn't save any 16-byte blocks.  */
> -       shrx    %SHIFT_REG, %VRSI, %VRSI
> -       test    %VRSI, %VRSI
> -       jnz     L(page_cross_return)
> -       jmp     L(page_cross_continue)
> -       /* 1-byte from cache line.  */
> +L(ret2):
> +       ret
> +       /* 3 bytes from cache-line for evex.  */
> +       /* 0 bytes from cache-line for evex512.  */
>  END(STRRCHR)
>  #endif
> --
> 2.34.1


LGTM
Reviewed-by: Sunil K Pandey <skpgkp2@gmail.com>
diff mbox series

Patch

diff --git a/sysdeps/x86_64/multiarch/strrchr-evex-base.S b/sysdeps/x86_64/multiarch/strrchr-evex-base.S
index cd6a0a870a..da5bf1237a 100644
--- a/sysdeps/x86_64/multiarch/strrchr-evex-base.S
+++ b/sysdeps/x86_64/multiarch/strrchr-evex-base.S
@@ -35,18 +35,20 @@ 
 #  define CHAR_SIZE	4
 #  define VPCMP		vpcmpd
 #  define VPMIN		vpminud
-#  define VPCOMPRESS	vpcompressd
 #  define VPTESTN	vptestnmd
 #  define VPTEST	vptestmd
 #  define VPBROADCAST	vpbroadcastd
 #  define VPCMPEQ	vpcmpeqd
 
 # else
-#  define SHIFT_REG	VRDI
+#  if VEC_SIZE == 64
+#   define SHIFT_REG	VRCX
+#  else
+#   define SHIFT_REG	VRDI
+#  endif
 #  define CHAR_SIZE	1
 #  define VPCMP		vpcmpb
 #  define VPMIN		vpminub
-#  define VPCOMPRESS	vpcompressb
 #  define VPTESTN	vptestnmb
 #  define VPTEST	vptestmb
 #  define VPBROADCAST	vpbroadcastb
@@ -56,6 +58,12 @@ 
 #  define KORTEST_M	KORTEST
 # endif
 
+# if VEC_SIZE == 32 || (defined USE_AS_WCSRCHR)
+#  define SHIFT_R(cnt, val)	shrx cnt, val, val
+# else
+#  define SHIFT_R(cnt, val)	shr %cl, val
+# endif
+
 # define VMATCH		VMM(0)
 # define CHAR_PER_VEC	(VEC_SIZE / CHAR_SIZE)
 # define PAGE_SIZE	4096
@@ -71,7 +79,7 @@  ENTRY_P2ALIGN(STRRCHR, 6)
 	andl	$(PAGE_SIZE - 1), %eax
 	cmpl	$(PAGE_SIZE - VEC_SIZE), %eax
 	jg	L(cross_page_boundary)
-
+L(page_cross_continue):
 	VMOVU	(%rdi), %VMM(1)
 	/* k0 has a 1 for each zero CHAR in YMM1.  */
 	VPTESTN	%VMM(1), %VMM(1), %k0
@@ -79,7 +87,7 @@  ENTRY_P2ALIGN(STRRCHR, 6)
 	test	%VGPR(rsi), %VGPR(rsi)
 	jz	L(aligned_more)
 	/* fallthrough: zero CHAR in first VEC.  */
-L(page_cross_return):
+
 	/* K1 has a 1 for each search CHAR match in VEC(1).  */
 	VPCMPEQ	%VMATCH, %VMM(1), %k1
 	KMOV	%k1, %VGPR(rax)
@@ -167,7 +175,6 @@  L(first_vec_x1_return):
 
 	.p2align 4,, 12
 L(aligned_more):
-L(page_cross_continue):
 	/* Need to keep original pointer incase VEC(1) has last match.  */
 	movq	%rdi, %rsi
 	andq	$-VEC_SIZE, %rdi
@@ -340,34 +347,54 @@  L(return_new_match_ret):
 	leaq	(VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %rax
 	ret
 
-	.p2align 4,, 4
 L(cross_page_boundary):
+	/* eax contains all the page offset bits of src (rdi). `xor rdi,
+	   rax` sets pointer will all page offset bits cleared so
+	   offset of (PAGE_SIZE - VEC_SIZE) will get last aligned VEC
+	   before page cross (guaranteed to be safe to read). Doing this
+	   as opposed to `movq %rdi, %rax; andq $-VEC_SIZE, %rax` saves
+	   a bit of code size.  */
 	xorq	%rdi, %rax
-	mov	$-1, %VRDX
-	VMOVU	(PAGE_SIZE - VEC_SIZE)(%rax), %VMM(6)
-	VPTESTN	%VMM(6), %VMM(6), %k0
+	VMOVU	(PAGE_SIZE - VEC_SIZE)(%rax), %VMM(1)
+	VPTESTN	%VMM(1), %VMM(1), %k0
 	KMOV	%k0, %VRSI
 
-# ifdef USE_AS_WCSRCHR
+	/* Shift out zero CHAR matches that are before the beginning of
+	   src (rdi).  */
+# if VEC_SIZE == 64 || (defined USE_AS_WCSRCHR)
 	movl	%edi, %ecx
-	and	$(VEC_SIZE - 1), %ecx
+# endif
+# ifdef USE_AS_WCSRCHR
+	andl	$(VEC_SIZE - 1), %ecx
 	shrl	$2, %ecx
 # endif
-	shlx	%SHIFT_REG, %VRDX, %VRDX
+	SHIFT_R	(%SHIFT_REG, %VRSI)
+# if VEC_SIZE == 32 || (defined USE_AS_WCSRCHR)
+	/* For strrchr-evex512 we use SHIFT_R as shr which will set zero
+	   flag.  */
+	test	%VRSI, %VRSI
+# endif
+	jz	L(page_cross_continue)
 
+	/* Found zero CHAR so need to test for search CHAR.  */
+	VPCMPEQ	%VMATCH, %VMM(1), %k1
+	KMOV	%k1, %VRAX
+	/* Shift out search CHAR matches that are before the beginning of
+	   src (rdi).  */
+	SHIFT_R	(%SHIFT_REG, %VRAX)
+	/* Check if any search CHAR match in range.  */
+	blsmsk	%VRSI, %VRSI
+	and	%VRSI, %VRAX
+	jz	L(ret2)
+	bsr	%VRAX, %VRAX
 # ifdef USE_AS_WCSRCHR
-	kmovw	%edx, %k1
+	leaq	(%rdi, %rax, CHAR_SIZE), %rax
 # else
-	KMOV	%VRDX, %k1
+	addq	%rdi, %rax
 # endif
-
-	VPCOMPRESS %VMM(6), %VMM(1){%k1}{z}
-	/* We could technically just jmp back after the vpcompress but
-	   it doesn't save any 16-byte blocks.  */
-	shrx	%SHIFT_REG, %VRSI, %VRSI
-	test	%VRSI, %VRSI
-	jnz	L(page_cross_return)
-	jmp	L(page_cross_continue)
-	/* 1-byte from cache line.  */
+L(ret2):
+	ret
+	/* 3 bytes from cache-line for evex.  */
+	/* 0 bytes from cache-line for evex512.  */
 END(STRRCHR)
 #endif