[1/2] middle-end Handle difference between complex negations in SLP tree better (GCC 11 backport)

Message ID patch-15324-tamar@arm.com
State Committed
Headers
Series [1/2] middle-end Handle difference between complex negations in SLP tree better (GCC 11 backport) |

Commit Message

Tamar Christina Feb. 28, 2022, 11:29 a.m. UTC
  Hi All,

GCC 11 handled negations rather differently than GCC 12.  This difference
caused the previous backport to regress some of the conjugate cases that it
used to handle before.  The testsuite in GCC 11 wasn't as robust as that in
master so it didn't catch it.

The second patch in this series backports the testcases from master to GCC-11
to prevent this in the future.

This patch deals with the conjugate cases correctly by updating the detection
code to deal with the different order of operands.

For MUL the problem is that the presence of an ADD can cause the order of the
operands to flip, unlike in GCC 12.  So to handle this if we detect the shape
of a MUL but the data-flow check fails, we swap both operands and try again.

Since a * b == b * a this is fine and allows us to keep the df-check simple.
This doesn't cause a compile time issue either as most of the data will be in
the caches from the previous call.

Bootstrapped Regtested on aarch64-none-linux-gnu,
x86_64-pc-linux-gnu and no regressions on updated testsuite.

Ok for GCC 11?

Thanks,
Tamar

gcc/ChangeLog:

	* tree-vect-slp-patterns.c (vect_validate_multiplication): Correctly
	detect conjugate cases.
	(complex_mul_pattern::matches): Likewise.
	(complex_fma_pattern::matches): Move accumulator last as expected.
	(complex_fma_pattern::build): Likewise.
	(complex_fms_pattern::matches): Handle different conjugate form.

--- inline copy of patch -- 
diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
index a3bd90ff85b4ca5423a94388d480b66051a83e08..8b08a0f33dd1cd23ad9577243524c1feaa5e8ed9 100644


--
  

Comments

Richard Biener Feb. 28, 2022, 12:48 p.m. UTC | #1
On Mon, 28 Feb 2022, Tamar Christina wrote:

> Hi All,
> 
> GCC 11 handled negations rather differently than GCC 12.  This difference
> caused the previous backport to regress some of the conjugate cases that it
> used to handle before.  The testsuite in GCC 11 wasn't as robust as that in
> master so it didn't catch it.
> 
> The second patch in this series backports the testcases from master to GCC-11
> to prevent this in the future.
> 
> This patch deals with the conjugate cases correctly by updating the detection
> code to deal with the different order of operands.
> 
> For MUL the problem is that the presence of an ADD can cause the order of the
> operands to flip, unlike in GCC 12.  So to handle this if we detect the shape
> of a MUL but the data-flow check fails, we swap both operands and try again.
> 
> Since a * b == b * a this is fine and allows us to keep the df-check simple.
> This doesn't cause a compile time issue either as most of the data will be in
> the caches from the previous call.
> 
> Bootstrapped Regtested on aarch64-none-linux-gnu,
> x86_64-pc-linux-gnu and no regressions on updated testsuite.
> 
> Ok for GCC 11?

OK.

> Thanks,
> Tamar
> 
> gcc/ChangeLog:
> 
> 	* tree-vect-slp-patterns.c (vect_validate_multiplication): Correctly
> 	detect conjugate cases.
> 	(complex_mul_pattern::matches): Likewise.
> 	(complex_fma_pattern::matches): Move accumulator last as expected.
> 	(complex_fma_pattern::build): Likewise.
> 	(complex_fms_pattern::matches): Handle different conjugate form.
> 
> --- inline copy of patch -- 
> diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
> index a3bd90ff85b4ca5423a94388d480b66051a83e08..8b08a0f33dd1cd23ad9577243524c1feaa5e8ed9 100644
> --- a/gcc/tree-vect-slp-patterns.c
> +++ b/gcc/tree-vect-slp-patterns.c
> @@ -873,10 +873,8 @@ compatible_complex_nodes_p (slp_compat_nodes_map_t *compat_cache,
>  static inline bool
>  vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
>  			      slp_compat_nodes_map_t *compat_cache,
> -			      vec<slp_tree> &left_op,
> -			      vec<slp_tree> &right_op,
> -			      bool subtract,
> -			      enum _conj_status *_status)
> +			      vec<slp_tree> &left_op, vec<slp_tree> &right_op,
> +			      bool subtract, enum _conj_status *_status)
>  {
>    auto_vec<slp_tree> ops;
>    enum _conj_status stats = CONJ_NONE;
> @@ -902,29 +900,31 @@ vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
>  
>    /* Default to style and perm 0, most operations use this one.  */
>    int style = 0;
> -  int perm = subtract ? 1 : 0;
> +  int perm = 0;
>  
> -  /* Check if we have a negate operation, if so absorb the node and continue
> -     looking.  */
> +  /* Determine which style we're looking at.  We only have different ones
> +     whenever a conjugate is involved.  If so absorb the node and continue.  */
>    bool neg0 = vect_match_expression_p (right_op[0], NEGATE_EXPR);
>    bool neg1 = vect_match_expression_p (right_op[1], NEGATE_EXPR);
>  
> -  /* Determine which style we're looking at.  We only have different ones
> -     whenever a conjugate is involved.  */
> -  if (neg0 && neg1)
> -    ;
> -  else if (neg0)
> -    {
> -      right_op[0] = SLP_TREE_CHILDREN (right_op[0])[0];
> -      stats = CONJ_FST;
> -      if (subtract)
> -	perm = 0;
> -    }
> -  else if (neg1)
> +   /* Determine which style we're looking at.  We only have different ones
> +      whenever a conjugate is involved.  */
> +  if (neg0 != neg1 && (neg0 || neg1))
>      {
> -      right_op[1] = SLP_TREE_CHILDREN (right_op[1])[0];
> -      stats = CONJ_SND;
> -      perm = 1;
> +      unsigned idx = !!neg1;
> +      right_op[idx] = SLP_TREE_CHILDREN (right_op[idx])[0];
> +      if (linear_loads_p (perm_cache, left_op[!!!neg1]) == PERM_EVENEVEN)
> +	{
> +	  stats = CONJ_FST;
> +	  style = 1;
> +	  if (subtract && neg0)
> +	    perm = 1;
> +	}
> +      else
> +	{
> +	  stats = CONJ_SND;
> +	  perm = 1;
> +	}
>      }
>  
>    *_status = stats;
> @@ -1069,7 +1069,16 @@ complex_mul_pattern::matches (complex_operation_t op,
>    enum _conj_status status;
>    if (!vect_validate_multiplication (perm_cache, compat_cache, left_op,
>  				     right_op, false, &status))
> -    return IFN_LAST;
> +    {
> +	/* Try swapping the operands and trying again.  */
> +	std::swap (left_op[0], left_op[1]);
> +	right_op.truncate (0);
> +	right_op.safe_splice (SLP_TREE_CHILDREN (muls[1]));
> +	std::swap (right_op[0], right_op[1]);
> +	if (!vect_validate_multiplication (perm_cache, compat_cache, left_op,
> +					   right_op, false, &status))
> +	  return IFN_LAST;
> +    }
>  
>    if (status == CONJ_NONE)
>      ifn = IFN_COMPLEX_MUL;
> @@ -1089,7 +1098,7 @@ complex_mul_pattern::matches (complex_operation_t op,
>        ops->quick_push (right_op[1]);
>        ops->quick_push (left_op[0]);
>      }
> -  else if (kind == PERM_EVENEVEN && status != CONJ_SND)
> +  else if (kind == PERM_EVENEVEN && status == CONJ_NONE)
>      {
>        ops->quick_push (left_op[0]);
>        ops->quick_push (right_op[0]);
> @@ -1246,15 +1255,15 @@ complex_fma_pattern::matches (complex_operation_t op,
>  
>    if (ifn == IFN_COMPLEX_FMA)
>      {
> -      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
>        ops->quick_push (SLP_TREE_CHILDREN (node)[1]);
>        ops->quick_push (SLP_TREE_CHILDREN (node)[0]);
> +      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
>      }
>    else
>      {
> -      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
>        ops->quick_push (SLP_TREE_CHILDREN (node)[0]);
>        ops->quick_push (SLP_TREE_CHILDREN (node)[1]);
> +      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
>      }
>  
>    return ifn;
> @@ -1290,8 +1299,8 @@ complex_fma_pattern::build (vec_info *vinfo)
>    SLP_TREE_CHILDREN (*this->m_node).create (3);
>    SLP_TREE_CHILDREN (*this->m_node).safe_splice (this->m_ops);
>  
> +  SLP_TREE_REF_COUNT (this->m_ops[0])++;
>    SLP_TREE_REF_COUNT (this->m_ops[1])++;
> -  SLP_TREE_REF_COUNT (this->m_ops[2])++;
>  
>    vect_free_slp_tree (node);
>  
> @@ -1397,24 +1406,32 @@ complex_fms_pattern::matches (complex_operation_t op,
>    if (!vect_pattern_validate_optab (ifn, *ref_node))
>      return IFN_LAST;
>  
> +  child = SLP_TREE_CHILDREN ((*ops)[1])[0];
>    ops->truncate (0);
>    ops->create (4);
>  
>    complex_perm_kinds_t kind = linear_loads_p (perm_cache, right_op[0]);
> -  if (kind == PERM_EVENODD)
> +  if (kind == PERM_EVENODD || kind == PERM_TOP)
>      {
>        ops->quick_push (child);
>        ops->quick_push (right_op[0]);
>        ops->quick_push (right_op[1]);
> -      ops->quick_push (left_op[1]);
> +      ops->quick_push (left_op[0]);
>      }
> -  else
> +  else if (status == CONJ_NONE)
>      {
>        ops->quick_push (child);
>        ops->quick_push (right_op[1]);
>        ops->quick_push (right_op[0]);
>        ops->quick_push (left_op[0]);
>      }
> +  else
> +    {
> +      ops->quick_push (child);
> +      ops->quick_push (right_op[1]);
> +      ops->quick_push (right_op[0]);
> +      ops->quick_push (left_op[1]);
> +    }
>  
>    return ifn;
>  }
> 
> 
>
  

Patch

diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
index a3bd90ff85b4ca5423a94388d480b66051a83e08..8b08a0f33dd1cd23ad9577243524c1feaa5e8ed9 100644
--- a/gcc/tree-vect-slp-patterns.c
+++ b/gcc/tree-vect-slp-patterns.c
@@ -873,10 +873,8 @@  compatible_complex_nodes_p (slp_compat_nodes_map_t *compat_cache,
 static inline bool
 vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
 			      slp_compat_nodes_map_t *compat_cache,
-			      vec<slp_tree> &left_op,
-			      vec<slp_tree> &right_op,
-			      bool subtract,
-			      enum _conj_status *_status)
+			      vec<slp_tree> &left_op, vec<slp_tree> &right_op,
+			      bool subtract, enum _conj_status *_status)
 {
   auto_vec<slp_tree> ops;
   enum _conj_status stats = CONJ_NONE;
@@ -902,29 +900,31 @@  vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
 
   /* Default to style and perm 0, most operations use this one.  */
   int style = 0;
-  int perm = subtract ? 1 : 0;
+  int perm = 0;
 
-  /* Check if we have a negate operation, if so absorb the node and continue
-     looking.  */
+  /* Determine which style we're looking at.  We only have different ones
+     whenever a conjugate is involved.  If so absorb the node and continue.  */
   bool neg0 = vect_match_expression_p (right_op[0], NEGATE_EXPR);
   bool neg1 = vect_match_expression_p (right_op[1], NEGATE_EXPR);
 
-  /* Determine which style we're looking at.  We only have different ones
-     whenever a conjugate is involved.  */
-  if (neg0 && neg1)
-    ;
-  else if (neg0)
-    {
-      right_op[0] = SLP_TREE_CHILDREN (right_op[0])[0];
-      stats = CONJ_FST;
-      if (subtract)
-	perm = 0;
-    }
-  else if (neg1)
+   /* Determine which style we're looking at.  We only have different ones
+      whenever a conjugate is involved.  */
+  if (neg0 != neg1 && (neg0 || neg1))
     {
-      right_op[1] = SLP_TREE_CHILDREN (right_op[1])[0];
-      stats = CONJ_SND;
-      perm = 1;
+      unsigned idx = !!neg1;
+      right_op[idx] = SLP_TREE_CHILDREN (right_op[idx])[0];
+      if (linear_loads_p (perm_cache, left_op[!!!neg1]) == PERM_EVENEVEN)
+	{
+	  stats = CONJ_FST;
+	  style = 1;
+	  if (subtract && neg0)
+	    perm = 1;
+	}
+      else
+	{
+	  stats = CONJ_SND;
+	  perm = 1;
+	}
     }
 
   *_status = stats;
@@ -1069,7 +1069,16 @@  complex_mul_pattern::matches (complex_operation_t op,
   enum _conj_status status;
   if (!vect_validate_multiplication (perm_cache, compat_cache, left_op,
 				     right_op, false, &status))
-    return IFN_LAST;
+    {
+	/* Try swapping the operands and trying again.  */
+	std::swap (left_op[0], left_op[1]);
+	right_op.truncate (0);
+	right_op.safe_splice (SLP_TREE_CHILDREN (muls[1]));
+	std::swap (right_op[0], right_op[1]);
+	if (!vect_validate_multiplication (perm_cache, compat_cache, left_op,
+					   right_op, false, &status))
+	  return IFN_LAST;
+    }
 
   if (status == CONJ_NONE)
     ifn = IFN_COMPLEX_MUL;
@@ -1089,7 +1098,7 @@  complex_mul_pattern::matches (complex_operation_t op,
       ops->quick_push (right_op[1]);
       ops->quick_push (left_op[0]);
     }
-  else if (kind == PERM_EVENEVEN && status != CONJ_SND)
+  else if (kind == PERM_EVENEVEN && status == CONJ_NONE)
     {
       ops->quick_push (left_op[0]);
       ops->quick_push (right_op[0]);
@@ -1246,15 +1255,15 @@  complex_fma_pattern::matches (complex_operation_t op,
 
   if (ifn == IFN_COMPLEX_FMA)
     {
-      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
       ops->quick_push (SLP_TREE_CHILDREN (node)[1]);
       ops->quick_push (SLP_TREE_CHILDREN (node)[0]);
+      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
     }
   else
     {
-      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
       ops->quick_push (SLP_TREE_CHILDREN (node)[0]);
       ops->quick_push (SLP_TREE_CHILDREN (node)[1]);
+      ops->quick_push (SLP_TREE_CHILDREN (vnode)[0]);
     }
 
   return ifn;
@@ -1290,8 +1299,8 @@  complex_fma_pattern::build (vec_info *vinfo)
   SLP_TREE_CHILDREN (*this->m_node).create (3);
   SLP_TREE_CHILDREN (*this->m_node).safe_splice (this->m_ops);
 
+  SLP_TREE_REF_COUNT (this->m_ops[0])++;
   SLP_TREE_REF_COUNT (this->m_ops[1])++;
-  SLP_TREE_REF_COUNT (this->m_ops[2])++;
 
   vect_free_slp_tree (node);
 
@@ -1397,24 +1406,32 @@  complex_fms_pattern::matches (complex_operation_t op,
   if (!vect_pattern_validate_optab (ifn, *ref_node))
     return IFN_LAST;
 
+  child = SLP_TREE_CHILDREN ((*ops)[1])[0];
   ops->truncate (0);
   ops->create (4);
 
   complex_perm_kinds_t kind = linear_loads_p (perm_cache, right_op[0]);
-  if (kind == PERM_EVENODD)
+  if (kind == PERM_EVENODD || kind == PERM_TOP)
     {
       ops->quick_push (child);
       ops->quick_push (right_op[0]);
       ops->quick_push (right_op[1]);
-      ops->quick_push (left_op[1]);
+      ops->quick_push (left_op[0]);
     }
-  else
+  else if (status == CONJ_NONE)
     {
       ops->quick_push (child);
       ops->quick_push (right_op[1]);
       ops->quick_push (right_op[0]);
       ops->quick_push (left_op[0]);
     }
+  else
+    {
+      ops->quick_push (child);
+      ops->quick_push (right_op[1]);
+      ops->quick_push (right_op[0]);
+      ops->quick_push (left_op[1]);
+    }
 
   return ifn;
 }