[v1,2/2] x86: Optimize L(less_vec) case in memcmpeq-evex.S

Message ID 20211225032257.2887327-2-goldstein.w.n@gmail.com
State Committed
Commit cca457f9c51a90cf82cae75432ed3de20942519c
Series [v1,1/2] x86: Optimize L(less_vec) case in memcmp-evex-movbe.S

Checks

Context Check Description
dj/TryBot-apply_patch success Patch applied to master at the time it was sent

Commit Message

Noah Goldstein Dec. 25, 2021, 3:22 a.m. UTC
No bug.
Optimizations are twofold.

1) Replace the page-cross and 0/1 checks with masked load instructions in
   L(less_vec). In applications this reduces branch misses in the
   hot [0, 32] case.
2) Change the control flow so that the L(less_vec) case gets the fall through.

Change 2) helps comparisons in the [0, 32] size range but comes at the
cost of comparisons in the [33, 64] size range.  From profiles of GCC and
Python3, 94%+ and 99%+ of calls are in the [0, 32] range, so this
appears to be the right tradeoff.
---
 sysdeps/x86_64/multiarch/memcmpeq-evex.S | 170 ++++++-----------------
 1 file changed, 43 insertions(+), 127 deletions(-)
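
The masked-load idea in 1), sketched below as hypothetical C intrinsics
(AVX512VL/AVX512BW plus BMI2); this is an illustrative stand-in, not the
hand-written assembly in the patch. A BZHI-built byte mask restricts both
the load and the compare to the first len bytes, so neither a zero length
nor a possible page cross needs its own branch.

    #include <immintrin.h>	/* AVX-512VL/BW and BMI2 intrinsics.  */
    #include <stddef.h>

    /* Illustrative sketch of the new [0, 32]-byte path; the function name
       and shape are hypothetical, only the technique mirrors the patch.  */
    static int
    memcmpeq_small_sketch (const void *s1, const void *s2, size_t len)
    {
      /* Bit i of k is set iff byte i is valid, i.e. i < len.  */
      __mmask32 k = _bzhi_u32 (0xffffffffU, (unsigned int) len);

      /* Masked loads only touch the selected bytes, so they cannot fault
         past the end of the buffer even if a full 32-byte load would
         cross into an unmapped page.  */
      __m256i a = _mm256_maskz_loadu_epi8 (k, s1);
      __m256i b = _mm256_maskz_loadu_epi8 (k, s2);

      /* Compare-not-equal under the same mask: the result is nonzero iff
         some byte in [0, len) differs, which is all __memcmpeq needs.  */
      return (int) _mm256_mask_cmpneq_epu8_mask (k, a, b);
    }

Built with -mavx512vl -mavx512bw -mbmi2, this should lower to roughly the
same bzhi/kmov/masked-vmovdqu8/vpcmpub sequence used at the entry of the
patched function.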
  

Comments

H.J. Lu Dec. 26, 2021, 5:05 p.m. UTC | #1
On Fri, Dec 24, 2021 at 7:23 PM Noah Goldstein <goldstein.w.n@gmail.com> wrote:

LGTM.

Reviewed-by: H.J. Lu <hjl.tools@gmail.com>

Thanks.
  

Patch

diff --git a/sysdeps/x86_64/multiarch/memcmpeq-evex.S b/sysdeps/x86_64/multiarch/memcmpeq-evex.S
index f27e732036..b5e1edbdff 100644
--- a/sysdeps/x86_64/multiarch/memcmpeq-evex.S
+++ b/sysdeps/x86_64/multiarch/memcmpeq-evex.S
@@ -39,6 +39,7 @@ 
 #  define MEMCMPEQ	__memcmpeq_evex
 # endif
 
+# define VMOVU_MASK	vmovdqu8
 # define VMOVU	vmovdqu64
 # define VPCMP	vpcmpub
 # define VPTEST	vptestmb
@@ -62,12 +63,39 @@  ENTRY_P2ALIGN (MEMCMPEQ, 6)
 	movl	%edx, %edx
 # endif
 	cmp	$VEC_SIZE, %RDX_LP
-	jb	L(less_vec)
+	/* Fall through for [0, VEC_SIZE] as its the hottest.  */
+	ja	L(more_1x_vec)
+
+	/* Create mask of bytes that are guranteed to be valid because
+	   of length (edx). Using masked movs allows us to skip checks for
+	   page crosses/zero size.  */
+	movl	$-1, %ecx
+	bzhil	%edx, %ecx, %ecx
+	kmovd	%ecx, %k2
+
+	/* Use masked loads as VEC_SIZE could page cross where length
+	   (edx) would not.  */
+	VMOVU_MASK (%rsi), %YMM2{%k2}
+	VPCMP	$4,(%rdi), %YMM2, %k1{%k2}
+	kmovd	%k1, %eax
+	ret
 
-	/* From VEC to 2 * VEC.  No branch when size == VEC_SIZE.  */
+
+L(last_1x_vec):
+	VMOVU	-(VEC_SIZE * 1)(%rsi, %rdx), %YMM1
+	VPCMP	$4, -(VEC_SIZE * 1)(%rdi, %rdx), %YMM1, %k1
+	kmovd	%k1, %eax
+L(return_neq0):
+	ret
+
+
+
+	.p2align 4
+L(more_1x_vec):
+	/* From VEC + 1 to 2 * VEC.  */
 	VMOVU	(%rsi), %YMM1
 	/* Use compare not equals to directly check for mismatch.  */
-	VPCMP	$4, (%rdi), %YMM1, %k1
+	VPCMP	$4,(%rdi), %YMM1, %k1
 	kmovd	%k1, %eax
 	testl	%eax, %eax
 	jnz	L(return_neq0)
@@ -88,13 +116,13 @@  ENTRY_P2ALIGN (MEMCMPEQ, 6)
 
 	/* Check third and fourth VEC no matter what.  */
 	VMOVU	(VEC_SIZE * 2)(%rsi), %YMM3
-	VPCMP	$4, (VEC_SIZE * 2)(%rdi), %YMM3, %k1
+	VPCMP	$4,(VEC_SIZE * 2)(%rdi), %YMM3, %k1
 	kmovd	%k1, %eax
 	testl	%eax, %eax
 	jnz	L(return_neq0)
 
 	VMOVU	(VEC_SIZE * 3)(%rsi), %YMM4
-	VPCMP	$4, (VEC_SIZE * 3)(%rdi), %YMM4, %k1
+	VPCMP	$4,(VEC_SIZE * 3)(%rdi), %YMM4, %k1
 	kmovd	%k1, %eax
 	testl	%eax, %eax
 	jnz	L(return_neq0)
@@ -132,66 +160,6 @@  ENTRY_P2ALIGN (MEMCMPEQ, 6)
 	/* Compare YMM4 with 0. If any 1s s1 and s2 don't match.  */
 	VPTEST	%YMM4, %YMM4, %k1
 	kmovd	%k1, %eax
-L(return_neq0):
-	ret
-
-	/* Fits in padding needed to .p2align 5 L(less_vec).  */
-L(last_1x_vec):
-	VMOVU	-(VEC_SIZE * 1)(%rsi, %rdx), %YMM1
-	VPCMP	$4, -(VEC_SIZE * 1)(%rdi, %rdx), %YMM1, %k1
-	kmovd	%k1, %eax
-	ret
-
-	/* NB: p2align 5 here will ensure the L(loop_4x_vec) is also 32
-	   byte aligned.  */
-	.p2align 5
-L(less_vec):
-	/* Check if one or less char. This is necessary for size = 0 but
-	   is also faster for size = 1.  */
-	cmpl	$1, %edx
-	jbe	L(one_or_less)
-
-	/* Check if loading one VEC from either s1 or s2 could cause a
-	   page cross. This can have false positives but is by far the
-	   fastest method.  */
-	movl	%edi, %eax
-	orl	%esi, %eax
-	andl	$(PAGE_SIZE - 1), %eax
-	cmpl	$(PAGE_SIZE - VEC_SIZE), %eax
-	jg	L(page_cross_less_vec)
-
-	/* No page cross possible.  */
-	VMOVU	(%rsi), %YMM2
-	VPCMP	$4, (%rdi), %YMM2, %k1
-	kmovd	%k1, %eax
-	/* Result will be zero if s1 and s2 match. Otherwise first set
-	   bit will be first mismatch.  */
-	bzhil	%edx, %eax, %eax
-	ret
-
-	/* Relatively cold but placing close to L(less_vec) for 2 byte
-	   jump encoding.  */
-	.p2align 4
-L(one_or_less):
-	jb	L(zero)
-	movzbl	(%rsi), %ecx
-	movzbl	(%rdi), %eax
-	subl	%ecx, %eax
-	/* No ymm register was touched.  */
-	ret
-	/* Within the same 16 byte block is L(one_or_less).  */
-L(zero):
-	xorl	%eax, %eax
-	ret
-
-	.p2align 4
-L(last_2x_vec):
-	VMOVU	-(VEC_SIZE * 2)(%rsi, %rdx), %YMM1
-	vpxorq	-(VEC_SIZE * 2)(%rdi, %rdx), %YMM1, %YMM1
-	VMOVU	-(VEC_SIZE * 1)(%rsi, %rdx), %YMM2
-	vpternlogd $0xde, -(VEC_SIZE * 1)(%rdi, %rdx), %YMM1, %YMM2
-	VPTEST	%YMM2, %YMM2, %k1
-	kmovd	%k1, %eax
 	ret
 
 	.p2align 4
@@ -211,7 +179,7 @@  L(loop_4x_vec):
 	vpxorq	(%rdi), %YMM1, %YMM1
 
 	VMOVU	VEC_SIZE(%rsi, %rdi), %YMM2
-	vpternlogd $0xde, (VEC_SIZE)(%rdi), %YMM1, %YMM2
+	vpternlogd $0xde,(VEC_SIZE)(%rdi), %YMM1, %YMM2
 
 	VMOVU	(VEC_SIZE * 2)(%rsi, %rdi), %YMM3
 	vpxorq	(VEC_SIZE * 2)(%rdi), %YMM3, %YMM3
@@ -238,7 +206,7 @@  L(loop_4x_vec):
 	VMOVU	(VEC_SIZE * 2)(%rsi, %rdx), %YMM3
 	/* Ternary logic to xor (VEC_SIZE * 2)(%rdx) with YMM3 while
 	   oring with YMM4. Result is stored in YMM4.  */
-	vpternlogd $0xf6, (VEC_SIZE * 2)(%rdx), %YMM3, %YMM4
+	vpternlogd $0xf6,(VEC_SIZE * 2)(%rdx), %YMM3, %YMM4
 	cmpl	$(VEC_SIZE * 2), %edi
 	jae	L(8x_last_2x_vec)
 
@@ -256,68 +224,16 @@  L(8x_last_2x_vec):
 L(return_neq2):
 	ret
 
-	/* Relatively cold case as page cross are unexpected.  */
-	.p2align 4
-L(page_cross_less_vec):
-	cmpl	$16, %edx
-	jae	L(between_16_31)
-	cmpl	$8, %edx
-	ja	L(between_9_15)
-	cmpl	$4, %edx
-	jb	L(between_2_3)
-	/* From 4 to 8 bytes.  No branch when size == 4.  */
-	movl	(%rdi), %eax
-	subl	(%rsi), %eax
-	movl	-4(%rdi, %rdx), %ecx
-	movl	-4(%rsi, %rdx), %edi
-	subl	%edi, %ecx
-	orl	%ecx, %eax
-	ret
-
-	.p2align 4,, 8
-L(between_16_31):
-	/* From 16 to 31 bytes.  No branch when size == 16.  */
-
-	/* Safe to use xmm[0, 15] as no vzeroupper is needed so RTM safe.
-	 */
-	vmovdqu	(%rsi), %xmm1
-	vpcmpeqb (%rdi), %xmm1, %xmm1
-	vmovdqu	-16(%rsi, %rdx), %xmm2
-	vpcmpeqb -16(%rdi, %rdx), %xmm2, %xmm2
-	vpand	%xmm1, %xmm2, %xmm2
-	vpmovmskb %xmm2, %eax
-	notw	%ax
-	/* No ymm register was touched.  */
-	ret
-
 	.p2align 4,, 8
-L(between_9_15):
-	/* From 9 to 15 bytes.  */
-	movq	(%rdi), %rax
-	subq	(%rsi), %rax
-	movq	-8(%rdi, %rdx), %rcx
-	movq	-8(%rsi, %rdx), %rdi
-	subq	%rdi, %rcx
-	orq	%rcx, %rax
-	/* edx is guranteed to be a non-zero int.  */
-	cmovnz	%edx, %eax
-	ret
-
-	/* Don't align. This is cold and aligning here will cause code
-	   to spill into next cache line.  */
-L(between_2_3):
-	/* From 2 to 3 bytes.  No branch when size == 2.  */
-	movzwl	(%rdi), %eax
-	movzwl	(%rsi), %ecx
-	subl	%ecx, %eax
-	movzbl	-1(%rdi, %rdx), %ecx
-	/* All machines that support evex will insert a "merging uop"
-	   avoiding any serious partial register stalls.  */
-	subb	-1(%rsi, %rdx), %cl
-	orl	%ecx, %eax
-	/* No ymm register was touched.  */
+L(last_2x_vec):
+	VMOVU	-(VEC_SIZE * 2)(%rsi, %rdx), %YMM1
+	vpxorq	-(VEC_SIZE * 2)(%rdi, %rdx), %YMM1, %YMM1
+	VMOVU	-(VEC_SIZE * 1)(%rsi, %rdx), %YMM2
+	vpternlogd $0xde, -(VEC_SIZE * 1)(%rdi, %rdx), %YMM1, %YMM2
+	VPTEST	%YMM2, %YMM2, %k1
+	kmovd	%k1, %eax
 	ret
 
-    /* 4 Bytes from next cache line. */
+    /* 1 Bytes from next cache line. */
 END (MEMCMPEQ)
 #endif