[v2,2/3] x86: Fix overflow bug with wmemchr-sse2 and wmemchr-avx2 [BZ #27974]

Message ID 20210622181111.185897-2-goldstein.w.n@gmail.com
State Committed
Commit 645a158978f9520e74074e8c14047503be4db0f0
Headers
Series None |

Commit Message

Noah Goldstein June 22, 2021, 6:11 p.m. UTC
  This commit fixes the bug mentioned in the previous commit.

The previous implementations of wmemchr in these files relied
on n * sizeof(wchar_t) which was not guranteed by the standard.

The new overflow tests added in the previous commit now
pass (As well as all the other tests).

Signed-off-by: Noah Goldstein <goldstein.w.n@gmail.com>
---
 sysdeps/x86_64/memchr.S                | 77 +++++++++++++++++++-------
 sysdeps/x86_64/multiarch/memchr-avx2.S | 58 +++++++++++++------
 2 files changed, 98 insertions(+), 37 deletions(-)
  

Comments

H.J. Lu June 22, 2021, 9:24 p.m. UTC | #1
On Tue, Jun 22, 2021 at 11:20 AM Noah Goldstein <goldstein.w.n@gmail.com> wrote:
>
> This commit fixes the bug mentioned in the previous commit.
>
> The previous implementations of wmemchr in these files relied
> on n * sizeof(wchar_t) which was not guranteed by the standard.
>
> The new overflow tests added in the previous commit now
> pass (As well as all the other tests).
>
> Signed-off-by: Noah Goldstein <goldstein.w.n@gmail.com>
> ---
>  sysdeps/x86_64/memchr.S                | 77 +++++++++++++++++++-------
>  sysdeps/x86_64/multiarch/memchr-avx2.S | 58 +++++++++++++------
>  2 files changed, 98 insertions(+), 37 deletions(-)
>
> diff --git a/sysdeps/x86_64/memchr.S b/sysdeps/x86_64/memchr.S
> index beff2708de..3ddc4655cf 100644
> --- a/sysdeps/x86_64/memchr.S
> +++ b/sysdeps/x86_64/memchr.S
> @@ -21,9 +21,11 @@
>  #ifdef USE_AS_WMEMCHR
>  # define MEMCHR                wmemchr
>  # define PCMPEQ                pcmpeqd
> +# define CHAR_PER_VEC  4
>  #else
>  # define MEMCHR                memchr
>  # define PCMPEQ                pcmpeqb
> +# define CHAR_PER_VEC  16
>  #endif
>
>  /* fast SSE2 version with using pmaxub and 64 byte loop */
> @@ -33,15 +35,14 @@ ENTRY(MEMCHR)
>         movd    %esi, %xmm1
>         mov     %edi, %ecx
>
> +#ifdef __ILP32__
> +       /* Clear the upper 32 bits.  */
> +       movl    %edx, %edx
> +#endif
>  #ifdef USE_AS_WMEMCHR
>         test    %RDX_LP, %RDX_LP
>         jz      L(return_null)
> -       shl     $2, %RDX_LP
>  #else
> -# ifdef __ILP32__
> -       /* Clear the upper 32 bits.  */
> -       movl    %edx, %edx
> -# endif
>         punpcklbw %xmm1, %xmm1
>         test    %RDX_LP, %RDX_LP
>         jz      L(return_null)
> @@ -60,13 +61,16 @@ ENTRY(MEMCHR)
>         test    %eax, %eax
>
>         jnz     L(matches_1)
> -       sub     $16, %rdx
> +       sub     $CHAR_PER_VEC, %rdx
>         jbe     L(return_null)
>         add     $16, %rdi
>         and     $15, %ecx
>         and     $-16, %rdi
> +#ifdef USE_AS_WMEMCHR
> +       shr     $2, %ecx
> +#endif
>         add     %rcx, %rdx
> -       sub     $64, %rdx
> +       sub     $(CHAR_PER_VEC * 4), %rdx
>         jbe     L(exit_loop)
>         jmp     L(loop_prolog)
>
> @@ -77,16 +81,21 @@ L(crosscache):
>         movdqa  (%rdi), %xmm0
>
>         PCMPEQ  %xmm1, %xmm0
> -/* Check if there is a match.  */
> +       /* Check if there is a match.  */
>         pmovmskb %xmm0, %eax
> -/* Remove the leading bytes.  */
> +       /* Remove the leading bytes.  */
>         sar     %cl, %eax
>         test    %eax, %eax
>         je      L(unaligned_no_match)
> -/* Check which byte is a match.  */
> +       /* Check which byte is a match.  */
>         bsf     %eax, %eax
> -
> +#ifdef USE_AS_WMEMCHR
> +       mov     %eax, %esi
> +       shr     $2, %esi
> +       sub     %rsi, %rdx
> +#else
>         sub     %rax, %rdx
> +#endif
>         jbe     L(return_null)
>         add     %rdi, %rax
>         add     %rcx, %rax
> @@ -94,15 +103,18 @@ L(crosscache):
>
>         .p2align 4
>  L(unaligned_no_match):
> -        /* "rcx" is less than 16.  Calculate "rdx + rcx - 16" by using
> +       /* "rcx" is less than 16.  Calculate "rdx + rcx - 16" by using
>            "rdx - (16 - rcx)" instead of "(rdx + rcx) - 16" to void
>            possible addition overflow.  */
>         neg     %rcx
>         add     $16, %rcx
> +#ifdef USE_AS_WMEMCHR
> +       shr     $2, %ecx
> +#endif
>         sub     %rcx, %rdx
>         jbe     L(return_null)
>         add     $16, %rdi
> -       sub     $64, %rdx
> +       sub     $(CHAR_PER_VEC * 4), %rdx
>         jbe     L(exit_loop)
>
>         .p2align 4
> @@ -135,7 +147,7 @@ L(loop_prolog):
>         test    $0x3f, %rdi
>         jz      L(align64_loop)
>
> -       sub     $64, %rdx
> +       sub     $(CHAR_PER_VEC * 4), %rdx
>         jbe     L(exit_loop)
>
>         movdqa  (%rdi), %xmm0
> @@ -167,11 +179,14 @@ L(loop_prolog):
>         mov     %rdi, %rcx
>         and     $-64, %rdi
>         and     $63, %ecx
> +#ifdef USE_AS_WMEMCHR
> +       shr     $2, %ecx
> +#endif
>         add     %rcx, %rdx
>
>         .p2align 4
>  L(align64_loop):
> -       sub     $64, %rdx
> +       sub     $(CHAR_PER_VEC * 4), %rdx
>         jbe     L(exit_loop)
>         movdqa  (%rdi), %xmm0
>         movdqa  16(%rdi), %xmm2
> @@ -218,7 +233,7 @@ L(align64_loop):
>
>         .p2align 4
>  L(exit_loop):
> -       add     $32, %edx
> +       add     $(CHAR_PER_VEC * 2), %edx
>         jle     L(exit_loop_32)
>
>         movdqa  (%rdi), %xmm0
> @@ -238,7 +253,7 @@ L(exit_loop):
>         pmovmskb %xmm3, %eax
>         test    %eax, %eax
>         jnz     L(matches32_1)
> -       sub     $16, %edx
> +       sub     $CHAR_PER_VEC, %edx
>         jle     L(return_null)
>
>         PCMPEQ  48(%rdi), %xmm1
> @@ -250,13 +265,13 @@ L(exit_loop):
>
>         .p2align 4
>  L(exit_loop_32):
> -       add     $32, %edx
> +       add     $(CHAR_PER_VEC * 2), %edx
>         movdqa  (%rdi), %xmm0
>         PCMPEQ  %xmm1, %xmm0
>         pmovmskb %xmm0, %eax
>         test    %eax, %eax
>         jnz     L(matches_1)
> -       sub     $16, %edx
> +       sub     $CHAR_PER_VEC, %edx
>         jbe     L(return_null)
>
>         PCMPEQ  16(%rdi), %xmm1
> @@ -293,7 +308,13 @@ L(matches32):
>         .p2align 4
>  L(matches_1):
>         bsf     %eax, %eax
> +#ifdef USE_AS_WMEMCHR
> +       mov     %eax, %esi
> +       shr     $2, %esi
> +       sub     %rsi, %rdx
> +#else
>         sub     %rax, %rdx
> +#endif
>         jbe     L(return_null)
>         add     %rdi, %rax
>         ret
> @@ -301,7 +322,13 @@ L(matches_1):
>         .p2align 4
>  L(matches16_1):
>         bsf     %eax, %eax
> +#ifdef USE_AS_WMEMCHR
> +       mov     %eax, %esi
> +       shr     $2, %esi
> +       sub     %rsi, %rdx
> +#else
>         sub     %rax, %rdx
> +#endif
>         jbe     L(return_null)
>         lea     16(%rdi, %rax), %rax
>         ret
> @@ -309,7 +336,13 @@ L(matches16_1):
>         .p2align 4
>  L(matches32_1):
>         bsf     %eax, %eax
> +#ifdef USE_AS_WMEMCHR
> +       mov     %eax, %esi
> +       shr     $2, %esi
> +       sub     %rsi, %rdx
> +#else
>         sub     %rax, %rdx
> +#endif
>         jbe     L(return_null)
>         lea     32(%rdi, %rax), %rax
>         ret
> @@ -317,7 +350,13 @@ L(matches32_1):
>         .p2align 4
>  L(matches48_1):
>         bsf     %eax, %eax
> +#ifdef USE_AS_WMEMCHR
> +       mov     %eax, %esi
> +       shr     $2, %esi
> +       sub     %rsi, %rdx
> +#else
>         sub     %rax, %rdx
> +#endif
>         jbe     L(return_null)
>         lea     48(%rdi, %rax), %rax
>         ret
> diff --git a/sysdeps/x86_64/multiarch/memchr-avx2.S b/sysdeps/x86_64/multiarch/memchr-avx2.S
> index 0d8758e3e7..afdb956502 100644
> --- a/sysdeps/x86_64/multiarch/memchr-avx2.S
> +++ b/sysdeps/x86_64/multiarch/memchr-avx2.S
> @@ -54,21 +54,19 @@
>
>  # define VEC_SIZE 32
>  # define PAGE_SIZE 4096
> +# define CHAR_PER_VEC  (VEC_SIZE / CHAR_SIZE)
>
>         .section SECTION(.text),"ax",@progbits
>  ENTRY (MEMCHR)
>  # ifndef USE_AS_RAWMEMCHR
>         /* Check for zero length.  */
> -       test    %RDX_LP, %RDX_LP
> -       jz      L(null)
> -# endif
> -# ifdef USE_AS_WMEMCHR
> -       shl     $2, %RDX_LP
> -# else
>  #  ifdef __ILP32__
> -       /* Clear the upper 32 bits.  */
> -       movl    %edx, %edx
> +       /* Clear upper bits.  */
> +       and     %RDX_LP, %RDX_LP
> +#  else
> +       test    %RDX_LP, %RDX_LP
>  #  endif
> +       jz      L(null)
>  # endif
>         /* Broadcast CHAR to YMMMATCH.  */
>         vmovd   %esi, %xmm0
> @@ -84,7 +82,7 @@ ENTRY (MEMCHR)
>         vpmovmskb %ymm1, %eax
>  # ifndef USE_AS_RAWMEMCHR
>         /* If length < CHAR_PER_VEC handle special.  */
> -       cmpq    $VEC_SIZE, %rdx
> +       cmpq    $CHAR_PER_VEC, %rdx
>         jbe     L(first_vec_x0)
>  # endif
>         testl   %eax, %eax
> @@ -98,6 +96,10 @@ ENTRY (MEMCHR)
>  L(first_vec_x0):
>         /* Check if first match was before length.  */
>         tzcntl  %eax, %eax
> +#  ifdef USE_AS_WMEMCHR
> +       /* NB: Multiply length by 4 to get byte count.  */
> +       sall    $2, %edx
> +#  endif
>         xorl    %ecx, %ecx
>         cmpl    %eax, %edx
>         leaq    (%rdi, %rax), %rax
> @@ -110,12 +112,12 @@ L(null):
>  # endif
>         .p2align 4
>  L(cross_page_boundary):
> -       /* Save pointer before aligning as its original value is necessary
> -          for computer return address if byte is found or adjusting length
> -          if it is not and this is memchr.  */
> +       /* Save pointer before aligning as its original value is
> +          necessary for computer return address if byte is found or
> +          adjusting length if it is not and this is memchr.  */
>         movq    %rdi, %rcx
> -       /* Align data to VEC_SIZE - 1. ALGN_PTR_REG is rcx for memchr and
> -          rdi for rawmemchr.  */
> +       /* Align data to VEC_SIZE - 1. ALGN_PTR_REG is rcx for memchr
> +          and rdi for rawmemchr.  */
>         orq     $(VEC_SIZE - 1), %ALGN_PTR_REG
>         VPCMPEQ -(VEC_SIZE - 1)(%ALGN_PTR_REG), %ymm0, %ymm1
>         vpmovmskb %ymm1, %eax
> @@ -124,6 +126,10 @@ L(cross_page_boundary):
>            match).  */
>         leaq    1(%ALGN_PTR_REG), %rsi
>         subq    %RRAW_PTR_REG, %rsi
> +#  ifdef USE_AS_WMEMCHR
> +       /* NB: Divide bytes by 4 to get wchar_t count.  */
> +       shrl    $2, %esi
> +#  endif
>  # endif
>         /* Remove the leading bytes.  */
>         sarxl   %ERAW_PTR_REG, %eax, %eax
> @@ -181,6 +187,10 @@ L(cross_page_continue):
>         orq     $(VEC_SIZE - 1), %rdi
>         /* esi is for adjusting length to see if near the end.  */
>         leal    (VEC_SIZE * 4 + 1)(%rdi, %rcx), %esi
> +#  ifdef USE_AS_WMEMCHR
> +       /* NB: Divide bytes by 4 to get the wchar_t count.  */
> +       sarl    $2, %esi
> +#  endif
>  # else
>         orq     $(VEC_SIZE - 1), %rdi
>  L(cross_page_continue):
> @@ -213,7 +223,7 @@ L(cross_page_continue):
>
>  # ifndef USE_AS_RAWMEMCHR
>         /* Check if at last VEC_SIZE * 4 length.  */
> -       subq    $(VEC_SIZE * 4), %rdx
> +       subq    $(CHAR_PER_VEC * 4), %rdx
>         jbe     L(last_4x_vec_or_less_cmpeq)
>         /* Align data to VEC_SIZE * 4 - 1 for the loop and readjust
>            length.  */
> @@ -221,6 +231,10 @@ L(cross_page_continue):
>         movl    %edi, %ecx
>         orq     $(VEC_SIZE * 4 - 1), %rdi
>         andl    $(VEC_SIZE * 4 - 1), %ecx
> +#  ifdef USE_AS_WMEMCHR
> +       /* NB: Divide bytes by 4 to get the wchar_t count.  */
> +       sarl    $2, %ecx
> +#  endif
>         addq    %rcx, %rdx
>  # else
>         /* Align data to VEC_SIZE * 4 - 1 for loop.  */
> @@ -250,15 +264,19 @@ L(loop_4x_vec):
>
>         subq    $-(VEC_SIZE * 4), %rdi
>
> -       subq    $(VEC_SIZE * 4), %rdx
> +       subq    $(CHAR_PER_VEC * 4), %rdx
>         ja      L(loop_4x_vec)
>
> -       /* Fall through into less than 4 remaining vectors of length case.
> -        */
> +       /* Fall through into less than 4 remaining vectors of length
> +          case.  */
>         VPCMPEQ (VEC_SIZE * 0 + 1)(%rdi), %ymm0, %ymm1
>         vpmovmskb %ymm1, %eax
>         .p2align 4
>  L(last_4x_vec_or_less):
> +#  ifdef USE_AS_WMEMCHR
> +       /* NB: Multiply length by 4 to get byte count.  */
> +       sall    $2, %edx
> +#  endif
>         /* Check if first VEC contained match.  */
>         testl   %eax, %eax
>         jnz     L(first_vec_x1_check)
> @@ -355,6 +373,10 @@ L(last_vec_x2_return):
>  L(last_4x_vec_or_less_cmpeq):
>         VPCMPEQ (VEC_SIZE * 4 + 1)(%rdi), %ymm0, %ymm1
>         vpmovmskb %ymm1, %eax
> +#  ifdef USE_AS_WMEMCHR
> +       /* NB: Multiply length by 4 to get byte count.  */
> +       sall    $2, %edx
> +#  endif
>         subq    $-(VEC_SIZE * 4), %rdi
>         /* Check first VEC regardless.  */
>         testl   %eax, %eax
> --
> 2.25.1
>

LGTM.

Reviewed-by: H.J. Lu <hjl.tools@gmail.com>

Thanks.
  

Patch

diff --git a/sysdeps/x86_64/memchr.S b/sysdeps/x86_64/memchr.S
index beff2708de..3ddc4655cf 100644
--- a/sysdeps/x86_64/memchr.S
+++ b/sysdeps/x86_64/memchr.S
@@ -21,9 +21,11 @@ 
 #ifdef USE_AS_WMEMCHR
 # define MEMCHR		wmemchr
 # define PCMPEQ		pcmpeqd
+# define CHAR_PER_VEC	4
 #else
 # define MEMCHR		memchr
 # define PCMPEQ		pcmpeqb
+# define CHAR_PER_VEC	16
 #endif
 
 /* fast SSE2 version with using pmaxub and 64 byte loop */
@@ -33,15 +35,14 @@  ENTRY(MEMCHR)
 	movd	%esi, %xmm1
 	mov	%edi, %ecx
 
+#ifdef __ILP32__
+	/* Clear the upper 32 bits.  */
+	movl	%edx, %edx
+#endif
 #ifdef USE_AS_WMEMCHR
 	test	%RDX_LP, %RDX_LP
 	jz	L(return_null)
-	shl	$2, %RDX_LP
 #else
-# ifdef __ILP32__
-	/* Clear the upper 32 bits.  */
-	movl	%edx, %edx
-# endif
 	punpcklbw %xmm1, %xmm1
 	test	%RDX_LP, %RDX_LP
 	jz	L(return_null)
@@ -60,13 +61,16 @@  ENTRY(MEMCHR)
 	test	%eax, %eax
 
 	jnz	L(matches_1)
-	sub	$16, %rdx
+	sub	$CHAR_PER_VEC, %rdx
 	jbe	L(return_null)
 	add	$16, %rdi
 	and	$15, %ecx
 	and	$-16, %rdi
+#ifdef USE_AS_WMEMCHR
+	shr	$2, %ecx
+#endif
 	add	%rcx, %rdx
-	sub	$64, %rdx
+	sub	$(CHAR_PER_VEC * 4), %rdx
 	jbe	L(exit_loop)
 	jmp	L(loop_prolog)
 
@@ -77,16 +81,21 @@  L(crosscache):
 	movdqa	(%rdi), %xmm0
 
 	PCMPEQ	%xmm1, %xmm0
-/* Check if there is a match.  */
+	/* Check if there is a match.  */
 	pmovmskb %xmm0, %eax
-/* Remove the leading bytes.  */
+	/* Remove the leading bytes.  */
 	sar	%cl, %eax
 	test	%eax, %eax
 	je	L(unaligned_no_match)
-/* Check which byte is a match.  */
+	/* Check which byte is a match.  */
 	bsf	%eax, %eax
-
+#ifdef USE_AS_WMEMCHR
+	mov	%eax, %esi
+	shr	$2, %esi
+	sub	%rsi, %rdx
+#else
 	sub	%rax, %rdx
+#endif
 	jbe	L(return_null)
 	add	%rdi, %rax
 	add	%rcx, %rax
@@ -94,15 +103,18 @@  L(crosscache):
 
 	.p2align 4
 L(unaligned_no_match):
-        /* "rcx" is less than 16.  Calculate "rdx + rcx - 16" by using
+	/* "rcx" is less than 16.  Calculate "rdx + rcx - 16" by using
 	   "rdx - (16 - rcx)" instead of "(rdx + rcx) - 16" to void
 	   possible addition overflow.  */
 	neg	%rcx
 	add	$16, %rcx
+#ifdef USE_AS_WMEMCHR
+	shr	$2, %ecx
+#endif
 	sub	%rcx, %rdx
 	jbe	L(return_null)
 	add	$16, %rdi
-	sub	$64, %rdx
+	sub	$(CHAR_PER_VEC * 4), %rdx
 	jbe	L(exit_loop)
 
 	.p2align 4
@@ -135,7 +147,7 @@  L(loop_prolog):
 	test	$0x3f, %rdi
 	jz	L(align64_loop)
 
-	sub	$64, %rdx
+	sub	$(CHAR_PER_VEC * 4), %rdx
 	jbe	L(exit_loop)
 
 	movdqa	(%rdi), %xmm0
@@ -167,11 +179,14 @@  L(loop_prolog):
 	mov	%rdi, %rcx
 	and	$-64, %rdi
 	and	$63, %ecx
+#ifdef USE_AS_WMEMCHR
+	shr	$2, %ecx
+#endif
 	add	%rcx, %rdx
 
 	.p2align 4
 L(align64_loop):
-	sub	$64, %rdx
+	sub	$(CHAR_PER_VEC * 4), %rdx
 	jbe	L(exit_loop)
 	movdqa	(%rdi), %xmm0
 	movdqa	16(%rdi), %xmm2
@@ -218,7 +233,7 @@  L(align64_loop):
 
 	.p2align 4
 L(exit_loop):
-	add	$32, %edx
+	add	$(CHAR_PER_VEC * 2), %edx
 	jle	L(exit_loop_32)
 
 	movdqa	(%rdi), %xmm0
@@ -238,7 +253,7 @@  L(exit_loop):
 	pmovmskb %xmm3, %eax
 	test	%eax, %eax
 	jnz	L(matches32_1)
-	sub	$16, %edx
+	sub	$CHAR_PER_VEC, %edx
 	jle	L(return_null)
 
 	PCMPEQ	48(%rdi), %xmm1
@@ -250,13 +265,13 @@  L(exit_loop):
 
 	.p2align 4
 L(exit_loop_32):
-	add	$32, %edx
+	add	$(CHAR_PER_VEC * 2), %edx
 	movdqa	(%rdi), %xmm0
 	PCMPEQ	%xmm1, %xmm0
 	pmovmskb %xmm0, %eax
 	test	%eax, %eax
 	jnz	L(matches_1)
-	sub	$16, %edx
+	sub	$CHAR_PER_VEC, %edx
 	jbe	L(return_null)
 
 	PCMPEQ	16(%rdi), %xmm1
@@ -293,7 +308,13 @@  L(matches32):
 	.p2align 4
 L(matches_1):
 	bsf	%eax, %eax
+#ifdef USE_AS_WMEMCHR
+	mov	%eax, %esi
+	shr	$2, %esi
+	sub	%rsi, %rdx
+#else
 	sub	%rax, %rdx
+#endif
 	jbe	L(return_null)
 	add	%rdi, %rax
 	ret
@@ -301,7 +322,13 @@  L(matches_1):
 	.p2align 4
 L(matches16_1):
 	bsf	%eax, %eax
+#ifdef USE_AS_WMEMCHR
+	mov	%eax, %esi
+	shr	$2, %esi
+	sub	%rsi, %rdx
+#else
 	sub	%rax, %rdx
+#endif
 	jbe	L(return_null)
 	lea	16(%rdi, %rax), %rax
 	ret
@@ -309,7 +336,13 @@  L(matches16_1):
 	.p2align 4
 L(matches32_1):
 	bsf	%eax, %eax
+#ifdef USE_AS_WMEMCHR
+	mov	%eax, %esi
+	shr	$2, %esi
+	sub	%rsi, %rdx
+#else
 	sub	%rax, %rdx
+#endif
 	jbe	L(return_null)
 	lea	32(%rdi, %rax), %rax
 	ret
@@ -317,7 +350,13 @@  L(matches32_1):
 	.p2align 4
 L(matches48_1):
 	bsf	%eax, %eax
+#ifdef USE_AS_WMEMCHR
+	mov	%eax, %esi
+	shr	$2, %esi
+	sub	%rsi, %rdx
+#else
 	sub	%rax, %rdx
+#endif
 	jbe	L(return_null)
 	lea	48(%rdi, %rax), %rax
 	ret
diff --git a/sysdeps/x86_64/multiarch/memchr-avx2.S b/sysdeps/x86_64/multiarch/memchr-avx2.S
index 0d8758e3e7..afdb956502 100644
--- a/sysdeps/x86_64/multiarch/memchr-avx2.S
+++ b/sysdeps/x86_64/multiarch/memchr-avx2.S
@@ -54,21 +54,19 @@ 
 
 # define VEC_SIZE 32
 # define PAGE_SIZE 4096
+# define CHAR_PER_VEC	(VEC_SIZE / CHAR_SIZE)
 
 	.section SECTION(.text),"ax",@progbits
 ENTRY (MEMCHR)
 # ifndef USE_AS_RAWMEMCHR
 	/* Check for zero length.  */
-	test	%RDX_LP, %RDX_LP
-	jz	L(null)
-# endif
-# ifdef USE_AS_WMEMCHR
-	shl	$2, %RDX_LP
-# else
 #  ifdef __ILP32__
-	/* Clear the upper 32 bits.  */
-	movl	%edx, %edx
+	/* Clear upper bits.  */
+	and	%RDX_LP, %RDX_LP
+#  else
+	test	%RDX_LP, %RDX_LP
 #  endif
+	jz	L(null)
 # endif
 	/* Broadcast CHAR to YMMMATCH.  */
 	vmovd	%esi, %xmm0
@@ -84,7 +82,7 @@  ENTRY (MEMCHR)
 	vpmovmskb %ymm1, %eax
 # ifndef USE_AS_RAWMEMCHR
 	/* If length < CHAR_PER_VEC handle special.  */
-	cmpq	$VEC_SIZE, %rdx
+	cmpq	$CHAR_PER_VEC, %rdx
 	jbe	L(first_vec_x0)
 # endif
 	testl	%eax, %eax
@@ -98,6 +96,10 @@  ENTRY (MEMCHR)
 L(first_vec_x0):
 	/* Check if first match was before length.  */
 	tzcntl	%eax, %eax
+#  ifdef USE_AS_WMEMCHR
+	/* NB: Multiply length by 4 to get byte count.  */
+	sall	$2, %edx
+#  endif
 	xorl	%ecx, %ecx
 	cmpl	%eax, %edx
 	leaq	(%rdi, %rax), %rax
@@ -110,12 +112,12 @@  L(null):
 # endif
 	.p2align 4
 L(cross_page_boundary):
-	/* Save pointer before aligning as its original value is necessary
-	   for computer return address if byte is found or adjusting length
-	   if it is not and this is memchr.  */
+	/* Save pointer before aligning as its original value is
+	   necessary for computer return address if byte is found or
+	   adjusting length if it is not and this is memchr.  */
 	movq	%rdi, %rcx
-	/* Align data to VEC_SIZE - 1. ALGN_PTR_REG is rcx for memchr and
-	   rdi for rawmemchr.  */
+	/* Align data to VEC_SIZE - 1. ALGN_PTR_REG is rcx for memchr
+	   and rdi for rawmemchr.  */
 	orq	$(VEC_SIZE - 1), %ALGN_PTR_REG
 	VPCMPEQ	-(VEC_SIZE - 1)(%ALGN_PTR_REG), %ymm0, %ymm1
 	vpmovmskb %ymm1, %eax
@@ -124,6 +126,10 @@  L(cross_page_boundary):
 	   match).  */
 	leaq	1(%ALGN_PTR_REG), %rsi
 	subq	%RRAW_PTR_REG, %rsi
+#  ifdef USE_AS_WMEMCHR
+	/* NB: Divide bytes by 4 to get wchar_t count.  */
+	shrl	$2, %esi
+#  endif
 # endif
 	/* Remove the leading bytes.  */
 	sarxl	%ERAW_PTR_REG, %eax, %eax
@@ -181,6 +187,10 @@  L(cross_page_continue):
 	orq	$(VEC_SIZE - 1), %rdi
 	/* esi is for adjusting length to see if near the end.  */
 	leal	(VEC_SIZE * 4 + 1)(%rdi, %rcx), %esi
+#  ifdef USE_AS_WMEMCHR
+	/* NB: Divide bytes by 4 to get the wchar_t count.  */
+	sarl	$2, %esi
+#  endif
 # else
 	orq	$(VEC_SIZE - 1), %rdi
 L(cross_page_continue):
@@ -213,7 +223,7 @@  L(cross_page_continue):
 
 # ifndef USE_AS_RAWMEMCHR
 	/* Check if at last VEC_SIZE * 4 length.  */
-	subq	$(VEC_SIZE * 4), %rdx
+	subq	$(CHAR_PER_VEC * 4), %rdx
 	jbe	L(last_4x_vec_or_less_cmpeq)
 	/* Align data to VEC_SIZE * 4 - 1 for the loop and readjust
 	   length.  */
@@ -221,6 +231,10 @@  L(cross_page_continue):
 	movl	%edi, %ecx
 	orq	$(VEC_SIZE * 4 - 1), %rdi
 	andl	$(VEC_SIZE * 4 - 1), %ecx
+#  ifdef USE_AS_WMEMCHR
+	/* NB: Divide bytes by 4 to get the wchar_t count.  */
+	sarl	$2, %ecx
+#  endif
 	addq	%rcx, %rdx
 # else
 	/* Align data to VEC_SIZE * 4 - 1 for loop.  */
@@ -250,15 +264,19 @@  L(loop_4x_vec):
 
 	subq	$-(VEC_SIZE * 4), %rdi
 
-	subq	$(VEC_SIZE * 4), %rdx
+	subq	$(CHAR_PER_VEC * 4), %rdx
 	ja	L(loop_4x_vec)
 
-	/* Fall through into less than 4 remaining vectors of length case.
-	 */
+	/* Fall through into less than 4 remaining vectors of length
+	   case.  */
 	VPCMPEQ	(VEC_SIZE * 0 + 1)(%rdi), %ymm0, %ymm1
 	vpmovmskb %ymm1, %eax
 	.p2align 4
 L(last_4x_vec_or_less):
+#  ifdef USE_AS_WMEMCHR
+	/* NB: Multiply length by 4 to get byte count.  */
+	sall	$2, %edx
+#  endif
 	/* Check if first VEC contained match.  */
 	testl	%eax, %eax
 	jnz	L(first_vec_x1_check)
@@ -355,6 +373,10 @@  L(last_vec_x2_return):
 L(last_4x_vec_or_less_cmpeq):
 	VPCMPEQ	(VEC_SIZE * 4 + 1)(%rdi), %ymm0, %ymm1
 	vpmovmskb %ymm1, %eax
+#  ifdef USE_AS_WMEMCHR
+	/* NB: Multiply length by 4 to get byte count.  */
+	sall	$2, %edx
+#  endif
 	subq	$-(VEC_SIZE * 4), %rdi
 	/* Check first VEC regardless.  */
 	testl	%eax, %eax