[v4,5/5] AArch64: Improve A64FX memset medium loops

Message ID VE1PR08MB5599EFE4EC93A5A2D268893383F69@VE1PR08MB5599.eurprd08.prod.outlook.com
State Committed
Commit a5db6a5cae6a92d1675c013e5c8d972768721576
Delegated to: Szabolcs Nagy
Series [v4,1/5] AArch64: Improve A64FX memset for small sizes

Commit Message

Wilco Dijkstra Aug. 9, 2021, 1:15 p.m. UTC
  v4: minor loop change

Simplify the code for memsets smaller than L1. Improve the unroll8 and L1_prefetch loops.
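For reference, a rough C sketch of the dispatch and loop structure after this patch (illustration only, not the actual implementation): the helpers store_bytes and prefetch_ahead are hypothetical stand-ins for the unrolled SVE st1b and prfm sequences, and the L1_SIZE value is an assumption matching the constant defined elsewhere in memset_a64fx.S.

#include <stddef.h>

#define L1_SIZE		(64 * 1024)		/* assumed A64FX L1D size */
#define L2_SIZE		(8 * 1024 * 1024)	/* L2 8MB */
#define CACHE_LINE_SIZE	256
#define PF_DIST_L1	(CACHE_LINE_SIZE * 16)	/* prefetch distance L1 */

/* Hypothetical stand-ins for the SVE store/prefetch macros: plain byte
   loops here, unrolled st1b/prfm sequences in the real code.  */
static void
store_bytes (char *dst, int c, size_t n)
{
  for (size_t i = 0; i < n; i++)
    dst[i] = (char) c;
}

static void
prefetch_ahead (char *dst, size_t offset)
{
  __builtin_prefetch (dst + offset, 1, 3);	/* roughly prfm pstl1keep */
}

static void
memset_medium_sketch (char *dstin, int c, size_t count, size_t vlen)
{
  char *dst = dstin;
  char *dstend = dstin + count;
  size_t step = 8 * vlen;			/* tmp1 in the assembly */

  if (count > L1_SIZE)
    {
      if (count >= L2_SIZE)
	{
	  store_bytes (dst, c, count);		/* L(L2): unchanged by this patch */
	  return;
	}
      if (vlen == 64)
	{
	  /* L(L1_prefetch): two cache lines per iteration, prefetching
	     PF_DIST_L1 bytes ahead, then fall into the unrolled loop.  */
	  while (count >= PF_DIST_L1)
	    {
	      store_bytes (dst, c, CACHE_LINE_SIZE);
	      prefetch_ahead (dst, PF_DIST_L1);
	      store_bytes (dst + CACHE_LINE_SIZE, c, CACHE_LINE_SIZE);
	      prefetch_ahead (dst, PF_DIST_L1 + CACHE_LINE_SIZE);
	      dst += 2 * CACHE_LINE_SIZE;
	      count -= 2 * CACHE_LINE_SIZE;
	    }
	}
    }

  /* L(unroll8): count >= 8 * vector_length on entry.  The assembly biases
     count by step up front so a single subs/b.hi both decrements and
     tests; this do/while iterates the same number of times.  */
  do
    {
      store_bytes (dst, c, step);
      dst += step;
      count -= step;
    }
  while (count > step);

  /* L(last): at most 8 * vector_length bytes remain; the real code
     finishes with overlapping st1b stores relative to dstend.  */
  store_bytes (dstend - count, c, count);
}

Note that the L1_prefetch loop now exits once fewer than PF_DIST_L1 bytes remain (rather than looping down to L1_SIZE), then finishes through unroll8 and the dstend-relative tail.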

---
  

Comments

Szabolcs Nagy Aug. 10, 2021, 9:44 a.m. UTC | #1
The 08/09/2021 13:15, Wilco Dijkstra via Libc-alpha wrote:
> v4: minor loop change
> 
> Simplify the code for memsets smaller than L1. Improve the unroll8 and L1_prefetch loops.

OK to commit, but keep

Reviewed-by: Naohiro Tamura <naohirot@fujitsu.com>

(further tweaks can go into follow-up commits.)

> 
> ---
> 
> diff --git a/sysdeps/aarch64/multiarch/memset_a64fx.S b/sysdeps/aarch64/multiarch/memset_a64fx.S
> index 89dba912588c243e67a9527a56b4d3a44659d542..318c6350a31e0fad788b5f2139de645ddc51493f 100644
> --- a/sysdeps/aarch64/multiarch/memset_a64fx.S
> +++ b/sysdeps/aarch64/multiarch/memset_a64fx.S
> @@ -30,7 +30,6 @@
>  #define L2_SIZE         (8*1024*1024)	// L2 8MB
>  #define CACHE_LINE_SIZE	256
>  #define PF_DIST_L1	(CACHE_LINE_SIZE * 16)	// Prefetch distance L1
> -#define rest		x2
>  #define vector_length	x9
>  
>  #if HAVE_AARCH64_SVE_ASM
> @@ -89,29 +88,19 @@ ENTRY (MEMSET)
>  
>  	.p2align 4
>  L(vl_agnostic): // VL Agnostic
> -	mov	rest, count
>  	mov	dst, dstin
> -	add	dstend, dstin, count
> -	// if rest >= L2_SIZE && vector_length == 64 then L(L2)
> -	mov	tmp1, 64
> -	cmp	rest, L2_SIZE
> -	ccmp	vector_length, tmp1, 0, cs
> -	b.eq	L(L2)
> -	// if rest >= L1_SIZE && vector_length == 64 then L(L1_prefetch)
> -	cmp	rest, L1_SIZE
> -	ccmp	vector_length, tmp1, 0, cs
> -	b.eq	L(L1_prefetch)
> -
> +	cmp	count, L1_SIZE
> +	b.hi	L(L1_prefetch)
>  
> +	// count >= 8 * vector_length
>  L(unroll8):
> -	lsl	tmp1, vector_length, 3
> -	.p2align 3
> -1:	cmp	rest, tmp1
> -	b.cc	L(last)
> -	st1b_unroll
> +	sub	count, count, tmp1
> +	.p2align 4
> +1:	st1b_unroll 0, 7
>  	add	dst, dst, tmp1
> -	sub	rest, rest, tmp1
> -	b	1b
> +	subs	count, count, tmp1
> +	b.hi	1b
> +	add	count, count, tmp1
>  
>  L(last):
>  	cmp	count, vector_length, lsl 1
> @@ -129,18 +118,22 @@ L(last):
>  	st1b	z0.b, p0, [dstend, -1, mul vl]
>  	ret
>  
> -L(L1_prefetch): // if rest >= L1_SIZE
> +	// count >= L1_SIZE
>  	.p2align 3
> +L(L1_prefetch):
> +	cmp	count, L2_SIZE
> +	b.hs	L(L2)
> +	cmp	vector_length, 64
> +	b.ne	L(unroll8)
>  1:	st1b_unroll 0, 3
>  	prfm	pstl1keep, [dst, PF_DIST_L1]
>  	st1b_unroll 4, 7
>  	prfm	pstl1keep, [dst, PF_DIST_L1 + CACHE_LINE_SIZE]
>  	add	dst, dst, CACHE_LINE_SIZE * 2
> -	sub	rest, rest, CACHE_LINE_SIZE * 2
> -	cmp	rest, L1_SIZE
> -	b.ge	1b
> -	cbnz	rest, L(unroll8)
> -	ret
> +	sub	count, count, CACHE_LINE_SIZE * 2
> +	cmp	count, PF_DIST_L1
> +	b.hs	1b
> +	b	L(unroll8)
>  
>  	// count >= L2_SIZE
>  	.p2align 3
> 

--
  

Patch

diff --git a/sysdeps/aarch64/multiarch/memset_a64fx.S b/sysdeps/aarch64/multiarch/memset_a64fx.S
index 89dba912588c243e67a9527a56b4d3a44659d542..318c6350a31e0fad788b5f2139de645ddc51493f 100644
--- a/sysdeps/aarch64/multiarch/memset_a64fx.S
+++ b/sysdeps/aarch64/multiarch/memset_a64fx.S
@@ -30,7 +30,6 @@ 
 #define L2_SIZE         (8*1024*1024)	// L2 8MB
 #define CACHE_LINE_SIZE	256
 #define PF_DIST_L1	(CACHE_LINE_SIZE * 16)	// Prefetch distance L1
-#define rest		x2
 #define vector_length	x9
 
 #if HAVE_AARCH64_SVE_ASM
@@ -89,29 +88,19 @@  ENTRY (MEMSET)
 
 	.p2align 4
 L(vl_agnostic): // VL Agnostic
-	mov	rest, count
 	mov	dst, dstin
-	add	dstend, dstin, count
-	// if rest >= L2_SIZE && vector_length == 64 then L(L2)
-	mov	tmp1, 64
-	cmp	rest, L2_SIZE
-	ccmp	vector_length, tmp1, 0, cs
-	b.eq	L(L2)
-	// if rest >= L1_SIZE && vector_length == 64 then L(L1_prefetch)
-	cmp	rest, L1_SIZE
-	ccmp	vector_length, tmp1, 0, cs
-	b.eq	L(L1_prefetch)
-
+	cmp	count, L1_SIZE
+	b.hi	L(L1_prefetch)
 
+	// count >= 8 * vector_length
 L(unroll8):
-	lsl	tmp1, vector_length, 3
-	.p2align 3
-1:	cmp	rest, tmp1
-	b.cc	L(last)
-	st1b_unroll
+	sub	count, count, tmp1
+	.p2align 4
+1:	st1b_unroll 0, 7
 	add	dst, dst, tmp1
-	sub	rest, rest, tmp1
-	b	1b
+	subs	count, count, tmp1
+	b.hi	1b
+	add	count, count, tmp1
 
 L(last):
 	cmp	count, vector_length, lsl 1
@@ -129,18 +118,22 @@  L(last):
 	st1b	z0.b, p0, [dstend, -1, mul vl]
 	ret
 
-L(L1_prefetch): // if rest >= L1_SIZE
+	// count >= L1_SIZE
 	.p2align 3
+L(L1_prefetch):
+	cmp	count, L2_SIZE
+	b.hs	L(L2)
+	cmp	vector_length, 64
+	b.ne	L(unroll8)
 1:	st1b_unroll 0, 3
 	prfm	pstl1keep, [dst, PF_DIST_L1]
 	st1b_unroll 4, 7
 	prfm	pstl1keep, [dst, PF_DIST_L1 + CACHE_LINE_SIZE]
 	add	dst, dst, CACHE_LINE_SIZE * 2
-	sub	rest, rest, CACHE_LINE_SIZE * 2
-	cmp	rest, L1_SIZE
-	b.ge	1b
-	cbnz	rest, L(unroll8)
-	ret
+	sub	count, count, CACHE_LINE_SIZE * 2
+	cmp	count, PF_DIST_L1
+	b.hs	1b
+	b	L(unroll8)
 
 	// count >= L2_SIZE
 	.p2align 3