[RFC] aarch64: Fold lsl+lsr+orr to rev for half-width shifts

Message ID 43813c0a-ea45-4d70-bdb7-751613f895df@nvidia.com
State New
Series [RFC] aarch64: Fold lsl+lsr+orr to rev for half-width shifts

Checks

Context Check Description
linaro-tcwg-bot/tcwg_gcc_build--master-arm warning Skipped upon request
linaro-tcwg-bot/tcwg_gcc_build--master-aarch64 warning Skipped upon request

Commit Message

Dhruv Chawla Nov. 27, 2024, 4:53 a.m. UTC
  This patch modifies the intrinsic expanders to expand svlsl and svlsr to
unpredicated forms when the predicate is a ptrue. It also folds the
following pattern:

   lsl <y>, <x>, <shift>
   lsr <z>, <x>, <shift>
   orr <r>, <y>, <z>

to:

   revb/h/w <r>, <x>

when the shift amount is equal to half the element width of the <x>
register.

This relies on the RTL combiner folding the "ior (ashift, lshiftrt)"
pattern to a "rotate" when the shift amount is half the element width.
When the shift amount is 8 (on 16-bit elements), a "bswap" is generated
instead.
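
For example, the ror32_sve_lsl_imm case from the new shift_rev_1.c test
below (compiled with the test's -O3 -march=armv8.2-a+sve options):

   #include <arm_sve.h>

   svuint64_t
   ror32_sve_lsl_imm (svuint64_t r)
   {
     return svorr_u64_z (svptrue_b64 (),
                         svlsl_n_u64_z (svptrue_b64 (), r, 32),
                         svlsr_n_u64_z (svptrue_b64 (), r, 32));
   }

now compiles to a single predicated REVW instead of an LSL/LSR/ORR
sequence:

   ptrue   p3.b, all
   revw    z0.d, p3/m, z0.d
   ret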

While this works well, the problem is that the matchers for instructions
like SRA and ADR expect the shifts to be in an unspec form. So, to keep
matching those patterns when the unpredicated instructions are generated,
they have to be duplicated to also accept the unpredicated form; a sketch
of the kind of input this affects follows below. I am looking for feedback
on whether this is a good way to proceed with this problem, or whether
there is a better way to do it.
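
As a rough illustration of the SRA case (a hand-written sketch, not part of
the patch or its testsuite; the function name and the exact options are made
up for illustration, assuming something like -O3 -march=armv8.2-a+sve2):

   #include <stdint.h>

   /* Accumulate a right shift by a constant; with SVE2 this kind of loop
      is expected to vectorize to USRA.  */
   void
   sra_sketch (uint64_t *restrict acc, const uint64_t *restrict x, int n)
   {
     for (int i = 0; i < n; i++)
       acc[i] += x[i] >> 2;
   }

Before this patch the constant-amount vector shift still expanded to the
predicated unspec form, so combine could fold the shift and the accumulation
into the existing *aarch64_sve2_sra<mode> pattern. With the unpredicated
expansion the shift reaches combine as a plain lshiftrt, which is what the
new *_unpred variants are there to catch.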

The patch was bootstrapped and regtested on aarch64-linux-gnu.
  

Patch

diff --git a/gcc/config/aarch64/aarch64-sve-builtins-base.cc b/gcc/config/aarch64/aarch64-sve-builtins-base.cc
index 87e9909b55a..d91182b6454 100644
--- a/gcc/config/aarch64/aarch64-sve-builtins-base.cc
+++ b/gcc/config/aarch64/aarch64-sve-builtins-base.cc
@@ -1947,6 +1947,33 @@  public:
   {
     return f.fold_const_binary (LSHIFT_EXPR);
   }
+
+  rtx expand (function_expander &e) const override
+  {
+    tree pred = TREE_OPERAND (e.call_expr, 3);
+    if (is_ptrue (pred, GET_MODE_UNIT_SIZE (e.result_mode ())))
+      return e.use_unpred_insn (e.direct_optab_handler (ashl_optab));
+    return rtx_code_function::expand (e);
+  }
+};
+
+class svlsr_impl : public rtx_code_function
+{
+public:
+  CONSTEXPR svlsr_impl () : rtx_code_function (LSHIFTRT, LSHIFTRT) {}
+
+  gimple *fold (gimple_folder &f) const override
+  {
+    return f.fold_const_binary (RSHIFT_EXPR);
+  }
+
+  rtx expand (function_expander &e) const override
+  {
+    tree pred = TREE_OPERAND (e.call_expr, 3);
+    if (is_ptrue (pred, GET_MODE_UNIT_SIZE (e.result_mode ())))
+      return e.use_unpred_insn (e.direct_optab_handler (lshr_optab));
+    return rtx_code_function::expand (e);
+  }
 };
 
 class svmad_impl : public function_base
@@ -3315,7 +3342,7 @@  FUNCTION (svldnt1, svldnt1_impl,)
 FUNCTION (svlen, svlen_impl,)
 FUNCTION (svlsl, svlsl_impl,)
 FUNCTION (svlsl_wide, shift_wide, (ASHIFT, UNSPEC_ASHIFT_WIDE))
-FUNCTION (svlsr, rtx_code_function, (LSHIFTRT, LSHIFTRT))
+FUNCTION (svlsr, svlsr_impl, )
 FUNCTION (svlsr_wide, shift_wide, (LSHIFTRT, UNSPEC_LSHIFTRT_WIDE))
 FUNCTION (svmad, svmad_impl,)
 FUNCTION (svmax, rtx_code_function, (SMAX, UMAX, UNSPEC_COND_FMAX,
diff --git a/gcc/config/aarch64/aarch64-sve.md b/gcc/config/aarch64/aarch64-sve.md
index 9afd11d3476..3d0bd3b8a67 100644
--- a/gcc/config/aarch64/aarch64-sve.md
+++ b/gcc/config/aarch64/aarch64-sve.md
@@ -3233,6 +3233,55 @@ 
 ;; - REVW
 ;; -------------------------------------------------------------------------
 
+(define_insn_and_split "*v_rev<mode>"
+  [(set (match_operand:SVE_FULL_HSDI 0 "register_operand" "=w")
+	(rotate:SVE_FULL_HSDI
+	  (match_operand:SVE_FULL_HSDI 1 "register_operand" "w")
+	  (match_operand:SVE_FULL_HSDI 2 "aarch64_constant_vector_operand")))]
+  "TARGET_SVE"
+  "#"
+  "&& !reload_completed"
+  [(set (match_dup 3)
+	(ashift:SVE_FULL_HSDI (match_dup 1)
+			      (match_dup 2)))
+   (set (match_dup 0)
+	(plus:SVE_FULL_HSDI
+	  (lshiftrt:SVE_FULL_HSDI (match_dup 1)
+				  (match_dup 4))
+	  (match_dup 3)))]
+  {
+    if (aarch64_emit_opt_vec_rotate (operands[0], operands[1], operands[2]))
+      DONE;
+
+    operands[3] = gen_reg_rtx (<MODE>mode);
+    rtx shift_amount = unwrap_const_vec_duplicate (operands[2]);
+    int bitwidth = GET_MODE_UNIT_BITSIZE (<MODE>mode);
+    operands[4] = aarch64_simd_gen_const_vector_dup (<MODE>mode,
+						     bitwidth - INTVAL (shift_amount));
+  }
+)
+
+;; The RTL combiners are able to combine "ior (ashift, lshiftrt)" to a "bswap".
+;; Match that as well.
+(define_insn_and_split "*v_revvnx8hi"
+  [(set (match_operand:VNx8HI 0 "register_operand" "=w")
+	(bswap:VNx8HI
+	  (match_operand 1 "register_operand" "w")))]
+  "TARGET_SVE"
+  "#"
+  "&& !reload_completed"
+  [(set (match_dup 0)
+	(unspec:VNx8HI
+	  [(match_dup 2)
+	   (unspec:VNx8HI
+	     [(match_dup 1)]
+	     UNSPEC_REVB)]
+	  UNSPEC_PRED_X))]
+  {
+    operands[2] = aarch64_ptrue_reg (VNx8BImode);
+  }
+)
+
 ;; Predicated integer unary operations.
 (define_insn "@aarch64_pred_<optab><mode>"
   [(set (match_operand:SVE_FULL_I 0 "register_operand")
@@ -4163,6 +4212,17 @@ 
   }
 )
 
+(define_expand "@aarch64_adr<mode>_shift_unpred"
+  [(set (match_operand:SVE_FULL_SDI 0 "register_operand")
+	(plus:SVE_FULL_SDI
+	  (ashift:SVE_FULL_SDI
+	    (match_operand:SVE_FULL_SDI 2 "register_operand")
+	    (match_operand:SVE_FULL_SDI 3 "const_1_to_3_operand"))
+	  (match_operand:SVE_FULL_SDI 1 "register_operand")))]
+  "TARGET_SVE && TARGET_NON_STREAMING"
+  {}
+)
+
 (define_insn_and_rewrite "*aarch64_adr<mode>_shift"
   [(set (match_operand:SVE_24I 0 "register_operand" "=w")
 	(plus:SVE_24I
@@ -4181,6 +4241,17 @@ 
   }
 )
 
+(define_insn "*aarch64_adr<mode>_shift_unpred"
+  [(set (match_operand:SVE_24I 0 "register_operand" "=w")
+	(plus:SVE_24I
+	  (ashift:SVE_24I
+	    (match_operand:SVE_24I 2 "register_operand" "w")
+	    (match_operand:SVE_24I 3 "const_1_to_3_operand"))
+	  (match_operand:SVE_24I 1 "register_operand" "w")))]
+  "TARGET_SVE && TARGET_NON_STREAMING"
+  "adr\t%0.<Vctype>, [%1.<Vctype>, %2.<Vctype>, lsl %3]"
+)
+
 ;; Same, but with the index being sign-extended from the low 32 bits.
 (define_insn_and_rewrite "*aarch64_adr_shift_sxtw"
   [(set (match_operand:VNx2DI 0 "register_operand" "=w")
@@ -4205,6 +4276,26 @@ 
   }
 )
 
+(define_insn_and_rewrite "*aarch64_adr_shift_sxtw_unpred"
+  [(set (match_operand:VNx2DI 0 "register_operand" "=w")
+	(plus:VNx2DI
+	  (ashift:VNx2DI
+	    (unspec:VNx2DI
+	      [(match_operand 4)
+	       (sign_extend:VNx2DI
+		 (truncate:VNx2SI
+		   (match_operand:VNx2DI 2 "register_operand" "w")))]
+	     UNSPEC_PRED_X)
+	    (match_operand:VNx2DI 3 "const_1_to_3_operand"))
+	  (match_operand:VNx2DI 1 "register_operand" "w")))]
+  "TARGET_SVE && TARGET_NON_STREAMING"
+  "adr\t%0.d, [%1.d, %2.d, sxtw %3]"
+  "&& !CONSTANT_P (operands[4])"
+  {
+    operands[4] = CONSTM1_RTX (VNx2BImode);
+  }
+)
+
 ;; Same, but with the index being zero-extended from the low 32 bits.
 (define_insn_and_rewrite "*aarch64_adr_shift_uxtw"
   [(set (match_operand:VNx2DI 0 "register_operand" "=w")
@@ -4226,6 +4317,19 @@ 
   }
 )
 
+(define_insn "*aarch64_adr_shift_uxtw_unpred"
+  [(set (match_operand:VNx2DI 0 "register_operand" "=w")
+	(plus:VNx2DI
+	  (ashift:VNx2DI
+	    (and:VNx2DI
+	      (match_operand:VNx2DI 2 "register_operand" "w")
+	      (match_operand:VNx2DI 4 "aarch64_sve_uxtw_immediate"))
+	    (match_operand:VNx2DI 3 "const_1_to_3_operand"))
+	  (match_operand:VNx2DI 1 "register_operand" "w")))]
+  "TARGET_SVE && TARGET_NON_STREAMING"
+  "adr\t%0.d, [%1.d, %2.d, uxtw %3]"
+)
+
 ;; -------------------------------------------------------------------------
 ;; ---- [INT] Absolute difference
 ;; -------------------------------------------------------------------------
@@ -4804,6 +4908,9 @@ 
 
 ;; Unpredicated shift by a scalar, which expands into one of the vector
 ;; shifts below.
+;;
+;; The unpredicated form is emitted only when the shift amount is a constant
+;; value that is valid for the shift being carried out.
 (define_expand "<ASHIFT:optab><mode>3"
   [(set (match_operand:SVE_I 0 "register_operand")
 	(ASHIFT:SVE_I
@@ -4811,20 +4918,29 @@ 
 	  (match_operand:<VEL> 2 "general_operand")))]
   "TARGET_SVE"
   {
-    rtx amount;
+    rtx amount = NULL_RTX;
     if (CONST_INT_P (operands[2]))
       {
-	amount = gen_const_vec_duplicate (<MODE>mode, operands[2]);
-	if (!aarch64_sve_<lr>shift_operand (operands[2], <MODE>mode))
-	  amount = force_reg (<MODE>mode, amount);
+	if (aarch64_simd_shift_imm_p (operands[2], <MODE>mode, <optab>_optab == ashl_optab))
+	  operands[2] = aarch64_simd_gen_const_vector_dup (<MODE>mode, INTVAL (operands[2]));
+	else
+	  {
+	    amount = gen_const_vec_duplicate (<MODE>mode, operands[2]);
+	    if (!aarch64_sve_<lr>shift_operand (operands[2], <MODE>mode))
+	      amount = force_reg (<MODE>mode, amount);
+	  }
       }
     else
       {
 	amount = convert_to_mode (<VEL>mode, operands[2], 0);
 	amount = expand_vector_broadcast (<MODE>mode, amount);
       }
-    emit_insn (gen_v<optab><mode>3 (operands[0], operands[1], amount));
-    DONE;
+
+    if (amount)
+      {
+	emit_insn (gen_v<optab><mode>3 (operands[0], operands[1], amount));
+	DONE;
+      }
   }
 )
 
@@ -4868,27 +4984,27 @@ 
   ""
 )
 
-;; Unpredicated shift operations by a constant (post-RA only).
+;; Unpredicated shift operations by a constant.
 ;; These are generated by splitting a predicated instruction whose
 ;; predicate is unused.
-(define_insn "*post_ra_v_ashl<mode>3"
+(define_insn "*v_ashl<mode>3"
   [(set (match_operand:SVE_I 0 "register_operand")
 	(ashift:SVE_I
 	  (match_operand:SVE_I 1 "register_operand")
 	  (match_operand:SVE_I 2 "aarch64_simd_lshift_imm")))]
-  "TARGET_SVE && reload_completed"
+  "TARGET_SVE"
   {@ [ cons: =0 , 1 , 2   ]
      [ w	, w , vs1 ] add\t%0.<Vetype>, %1.<Vetype>, %1.<Vetype>
      [ w	, w , Dl  ] lsl\t%0.<Vetype>, %1.<Vetype>, #%2
   }
 )
 
-(define_insn "*post_ra_v_<optab><mode>3"
+(define_insn "*v_<optab><mode>3"
   [(set (match_operand:SVE_I 0 "register_operand" "=w")
 	(SHIFTRT:SVE_I
 	  (match_operand:SVE_I 1 "register_operand" "w")
 	  (match_operand:SVE_I 2 "aarch64_simd_rshift_imm")))]
-  "TARGET_SVE && reload_completed"
+  "TARGET_SVE"
   "<shift>\t%0.<Vetype>, %1.<Vetype>, #%2"
 )
 
diff --git a/gcc/config/aarch64/aarch64-sve2.md b/gcc/config/aarch64/aarch64-sve2.md
index 66affa85d36..b3fb0460b70 100644
--- a/gcc/config/aarch64/aarch64-sve2.md
+++ b/gcc/config/aarch64/aarch64-sve2.md
@@ -1876,6 +1876,16 @@ 
   }
 )
 
+(define_expand "@aarch64_sve_add_<sve_int_op><mode>_unpred"
+  [(set (match_operand:SVE_FULL_I 0 "register_operand")
+	(plus:SVE_FULL_I
+	  (SHIFTRT:SVE_FULL_I
+	    (match_operand:SVE_FULL_I 2 "register_operand")
+	    (match_operand:SVE_FULL_I 3 "aarch64_simd_rshift_imm"))
+	  (match_operand:SVE_FULL_I 1 "register_operand")))]
+  "TARGET_SVE2"
+)
+
 ;; Pattern-match SSRA and USRA as a predicated operation whose predicate
 ;; isn't needed.
 (define_insn_and_rewrite "*aarch64_sve2_sra<mode>"
@@ -1899,6 +1909,20 @@ 
   }
 )
 
+(define_insn "*aarch64_sve2_sra<mode>_unpred"
+  [(set (match_operand:SVE_FULL_I 0 "register_operand")
+	(plus:SVE_FULL_I
+	  (SHIFTRT:SVE_FULL_I
+	    (match_operand:SVE_FULL_I 2 "register_operand")
+	    (match_operand:SVE_FULL_I 3 "aarch64_simd_rshift_imm"))
+	 (match_operand:SVE_FULL_I 1 "register_operand")))]
+  "TARGET_SVE2"
+  {@ [ cons: =0 , 1 , 2 ; attrs: movprfx ]
+     [ w        , 0 , w ; *              ] <sra_op>sra\t%0.<Vetype>, %2.<Vetype>, #%3
+     [ ?&w      , w , w ; yes            ] movprfx\t%0, %1\;<sra_op>sra\t%0.<Vetype>, %2.<Vetype>, #%3
+  }
+)
+
 ;; SRSRA and URSRA.
 (define_insn "@aarch64_sve_add_<sve_int_op><mode>"
   [(set (match_operand:SVE_FULL_I 0 "register_operand")
@@ -2539,6 +2563,18 @@ 
   "addhnb\t%0.<Ventype>, %2.<Vetype>, %3.<Vetype>"
 )
 
+(define_insn "*bitmask_shift_plus_unpred<mode>"
+  [(set (match_operand:SVE_FULL_HSDI 0 "register_operand" "=w")
+	(lshiftrt:SVE_FULL_HSDI
+	  (plus:SVE_FULL_HSDI
+	    (match_operand:SVE_FULL_HSDI 1 "register_operand" "w")
+	    (match_operand:SVE_FULL_HSDI 2 "register_operand" "w"))
+	  (match_operand:SVE_FULL_HSDI 3
+	    "aarch64_simd_shift_imm_vec_exact_top" "")))]
+  "TARGET_SVE2"
+  "addhnb\t%0.<Ventype>, %1.<Vetype>, %2.<Vetype>"
+)
+
 ;; -------------------------------------------------------------------------
 ;; ---- [INT] Narrowing right shifts
 ;; -------------------------------------------------------------------------
diff --git a/gcc/testsuite/gcc.target/aarch64/sve/shift_rev_1.c b/gcc/testsuite/gcc.target/aarch64/sve/shift_rev_1.c
new file mode 100644
index 00000000000..3a30f80d152
--- /dev/null
+++ b/gcc/testsuite/gcc.target/aarch64/sve/shift_rev_1.c
@@ -0,0 +1,83 @@ 
+/* { dg-do compile } */
+/* { dg-options "-O3 -march=armv8.2-a+sve" } */
+/* { dg-final { check-function-bodies "**" "" "" } } */
+
+#include <arm_sve.h>
+
+/*
+** ror32_sve_lsl_imm:
+**	ptrue	p3.b, all
+**	revw	z0.d, p3/m, z0.d
+**	ret
+*/
+svuint64_t
+ror32_sve_lsl_imm (svuint64_t r)
+{
+  return svorr_u64_z (svptrue_b64 (), svlsl_n_u64_z (svptrue_b64 (), r, 32),
+		      svlsr_n_u64_z (svptrue_b64 (), r, 32));
+}
+
+/*
+** ror32_sve_lsl_operand:
+**	ptrue	p3.b, all
+**	revw	z0.d, p3/m, z0.d
+**	ret
+*/
+svuint64_t
+ror32_sve_lsl_operand (svuint64_t r)
+{
+  svbool_t pt = svptrue_b64 ();
+  return svorr_u64_z (pt, svlsl_n_u64_z (pt, r, 32), svlsr_n_u64_z (pt, r, 32));
+}
+
+/*
+** ror16_sve_lsl_imm:
+**	ptrue	p3.b, all
+**	revh	z0.s, p3/m, z0.s
+**	ret
+*/
+svuint32_t
+ror16_sve_lsl_imm (svuint32_t r)
+{
+  return svorr_u32_z (svptrue_b32 (), svlsl_n_u32_z (svptrue_b32 (), r, 16),
+		      svlsr_n_u32_z (svptrue_b32 (), r, 16));
+}
+
+/*
+** ror16_sve_lsl_operand:
+**	ptrue	p3.b, all
+**	revh	z0.s, p3/m, z0.s
+**	ret
+*/
+svuint32_t
+ror16_sve_lsl_operand (svuint32_t r)
+{
+  svbool_t pt = svptrue_b32 ();
+  return svorr_u32_z (pt, svlsl_n_u32_z (pt, r, 16), svlsr_n_u32_z (pt, r, 16));
+}
+
+/*
+** ror8_sve_lsl_imm:
+**	ptrue	p3.b, all
+**	revb	z0.h, p3/m, z0.h
+**	ret
+*/
+svuint16_t
+ror8_sve_lsl_imm (svuint16_t r)
+{
+  return svorr_u16_z (svptrue_b16 (), svlsl_n_u16_z (svptrue_b16 (), r, 8),
+		      svlsr_n_u16_z (svptrue_b16 (), r, 8));
+}
+
+/*
+** ror8_sve_lsl_operand:
+**	ptrue	p3.b, all
+**	revb	z0.h, p3/m, z0.h
+**	ret
+*/
+svuint16_t
+ror8_sve_lsl_operand (svuint16_t r)
+{
+  svbool_t pt = svptrue_b16 ();
+  return svorr_u16_z (pt, svlsl_n_u16_z (pt, r, 8), svlsr_n_u16_z (pt, r, 8));
+}
diff --git a/gcc/testsuite/gcc.target/aarch64/sve/shift_rev_2.c b/gcc/testsuite/gcc.target/aarch64/sve/shift_rev_2.c
new file mode 100644
index 00000000000..89d5a8a8b3e
--- /dev/null
+++ b/gcc/testsuite/gcc.target/aarch64/sve/shift_rev_2.c
@@ -0,0 +1,63 @@ 
+/* { dg-do compile } */
+/* { dg-options "-O3 -march=armv8.2-a+sve" } */
+
+#include <arm_sve.h>
+
+#define PTRUE_B(BITWIDTH) svptrue_b##BITWIDTH ()
+
+#define ROR_SVE_LSL(NAME, INPUT_TYPE, SHIFT_AMOUNT, BITWIDTH)                  \
+  INPUT_TYPE                                                                   \
+  NAME##_imm (INPUT_TYPE r)                                                    \
+  {                                                                            \
+    return svorr_u##BITWIDTH##_z (PTRUE_B (BITWIDTH),                          \
+				  svlsl_n_u##BITWIDTH##_z (PTRUE_B (BITWIDTH), \
+							   r, SHIFT_AMOUNT),   \
+				  svlsr_n_u##BITWIDTH##_z (PTRUE_B (BITWIDTH), \
+							   r, SHIFT_AMOUNT));  \
+  }                                                                            \
+                                                                               \
+  INPUT_TYPE                                                                   \
+  NAME##_operand (INPUT_TYPE r)                                                \
+  {                                                                            \
+    svbool_t pt = PTRUE_B (BITWIDTH);                                          \
+    return svorr_u##BITWIDTH##_z (                                             \
+      pt, svlsl_n_u##BITWIDTH##_z (pt, r, SHIFT_AMOUNT),                       \
+      svlsr_n_u##BITWIDTH##_z (pt, r, SHIFT_AMOUNT));                          \
+  }
+
+/* Make sure that the pattern doesn't match incorrect bit-widths, e.g. a shift
+   of 8 matching the 32-bit mode.  */
+
+ROR_SVE_LSL (higher_ror32, svuint64_t, 64, 64);
+ROR_SVE_LSL (higher_ror16, svuint32_t, 32, 32);
+ROR_SVE_LSL (higher_ror8, svuint16_t, 16, 16);
+
+ROR_SVE_LSL (lower_ror32, svuint64_t, 16, 64);
+ROR_SVE_LSL (lower_ror16, svuint32_t, 8, 32);
+ROR_SVE_LSL (lower_ror8, svuint16_t, 4, 16);
+
+/* Check off-by-one cases.  */
+
+ROR_SVE_LSL (off_1_high_ror32, svuint64_t, 33, 64);
+ROR_SVE_LSL (off_1_high_ror16, svuint32_t, 17, 32);
+ROR_SVE_LSL (off_1_high_ror8, svuint16_t, 9, 16);
+
+ROR_SVE_LSL (off_1_low_ror32, svuint64_t, 31, 64);
+ROR_SVE_LSL (off_1_low_ror16, svuint32_t, 15, 32);
+ROR_SVE_LSL (off_1_low_ror8, svuint16_t, 7, 16);
+
+/* Check out of bounds cases.  */
+
+ROR_SVE_LSL (oob_ror32, svuint64_t, 65, 64);
+ROR_SVE_LSL (oob_ror16, svuint32_t, 33, 32);
+ROR_SVE_LSL (oob_ror8, svuint16_t, 17, 16);
+
+/* Check zero case.  */
+
+ROR_SVE_LSL (zero_ror32, svuint64_t, 0, 64);
+ROR_SVE_LSL (zero_ror16, svuint32_t, 0, 32);
+ROR_SVE_LSL (zero_ror8, svuint16_t, 0, 16);
+
+/* { dg-final { scan-assembler-times "revb" 0 } } */
+/* { dg-final { scan-assembler-times "revh" 0 } } */
+/* { dg-final { scan-assembler-times "revw" 0 } } */