RISC-V: Refine Phase 3 of VSETVL PASS

Message ID 20230104134526.206115-1-juzhe.zhong@rivai.ai
State Accepted, archived
Headers
Series RISC-V: Refine Phase 3 of VSETVL PASS |

Commit Message

钟居哲 Jan. 4, 2023, 1:45 p.m. UTC
  From: Ju-Zhe Zhong <juzhe.zhong@rivai.ai>

gcc/ChangeLog:

        * config/riscv/riscv-vsetvl.cc (can_backward_propagate_p): Fix for null iter_bb.
        (vector_insn_info::set_demand_info): New function.
        (pass_vsetvl::emit_local_forward_vsetvls): Adjust for refinement of Phase 3.
        (pass_vsetvl::merge_successors): Ditto.
        (pass_vsetvl::compute_global_backward_infos): Rename to...
        (pass_vsetvl::backward_demand_fusion): ...this.
        (pass_vsetvl::forward_demand_fusion): New function.
        (pass_vsetvl::demand_fusion): New function.
        (pass_vsetvl::lazy_vsetvl): Adjust for refinement of Phase 3.
        * config/riscv/riscv-vsetvl.h: New function declaration.

---
 gcc/config/riscv/riscv-vsetvl.cc | 138 ++++++++++++++++++++++++++++---
 gcc/config/riscv/riscv-vsetvl.h  |   1 +
 2 files changed, 128 insertions(+), 11 deletions(-)
  

Comments

Kito Cheng Jan. 26, 2023, 7:14 p.m. UTC | #1
Committed with a few tweaks, thanks.

On Wed, Jan 4, 2023 at 9:46 PM <juzhe.zhong@rivai.ai> wrote:

> From: Ju-Zhe Zhong <juzhe.zhong@rivai.ai>
>
> gcc/ChangeLog:
>
>         * config/riscv/riscv-vsetvl.cc (can_backward_propagate_p): Fix for
> null iter_bb.
>         (vector_insn_info::set_demand_info): New function.
>         (pass_vsetvl::emit_local_forward_vsetvls): Adjust for refinement
> of Phase 3.
>         (pass_vsetvl::merge_successors): Ditto.
>         (pass_vsetvl::compute_global_backward_infos): Ditto.
>         (pass_vsetvl::backward_demand_fusion): Ditto.
>         (pass_vsetvl::forward_demand_fusion): Ditto.
>         (pass_vsetvl::demand_fusion): New function.
>         (pass_vsetvl::lazy_vsetvl): Adjust for refinement of phase 3.
>         * config/riscv/riscv-vsetvl.h: New function declaration.
>
> ---
>  gcc/config/riscv/riscv-vsetvl.cc | 138 ++++++++++++++++++++++++++++---
>  gcc/config/riscv/riscv-vsetvl.h  |   1 +
>  2 files changed, 128 insertions(+), 11 deletions(-)
>
> diff --git a/gcc/config/riscv/riscv-vsetvl.cc
> b/gcc/config/riscv/riscv-vsetvl.cc
> index 52f0195980a..d42cfa91d63 100644
> --- a/gcc/config/riscv/riscv-vsetvl.cc
> +++ b/gcc/config/riscv/riscv-vsetvl.cc
> @@ -43,7 +43,7 @@ along with GCC; see the file COPYING3.  If not see
>      -  Phase 2 - Emit vsetvl instructions within each basic block
> according to
>         demand, compute and save ANTLOC && AVLOC of each block.
>
> -    -  Phase 3 - Backward demanded info propagation and fusion across
> blocks.
> +    -  Phase 3 - Backward && forward demanded info propagation and fusion
> across blocks.
>
>      -  Phase 4 - Lazy code motion including: compute local properties,
>         pre_edge_lcm and vsetvl insertion && delete edges for LCM results.
> @@ -434,8 +434,12 @@ can_backward_propagate_p (const function_info *ssa,
> const basic_block cfg_bb,
>         set_info *ultimate_def = look_through_degenerate_phi (set);
>         const basic_block ultimate_bb = ultimate_def->bb ()->cfg_bb ();
>         FOR_BB_BETWEEN (iter_bb, ultimate_bb, def->bb ()->cfg_bb (),
> next_bb)
> -         if (iter_bb->index == cfg_bb->index)
> -           return true;
> +         {
> +           if (!iter_bb)
> +             break;
> +           if (iter_bb->index == cfg_bb->index)
> +             return true;
> +         }
>
>         return false;
>        };
> @@ -1172,6 +1176,19 @@ vector_insn_info::parse_insn (insn_info *insn)
>      m_demands[DEMAND_MASK_POLICY] = true;
>  }
>
> +void
> +vector_insn_info::set_demand_info (const vector_insn_info &other)
> +{
> +  set_sew (other.get_sew ());
> +  set_vlmul (other.get_vlmul ());
> +  set_ratio (other.get_ratio ());
> +  set_ta (other.get_ta ());
> +  set_ma (other.get_ma ());
> +  set_avl_info (other.get_avl_info ());
> +  for (size_t i = 0; i < NUM_DEMAND; i++)
> +    m_demands[i] = other.demand_p ((enum demand_type) i);
> +}
> +
>  void
>  vector_insn_info::demand_vl_vtype ()
>  {
> @@ -1691,8 +1708,10 @@ private:
>    void emit_local_forward_vsetvls (const bb_info *);
>
>    /* Phase 3.  */
> -  void merge_successors (const basic_block, const basic_block);
> -  void compute_global_backward_infos (void);
> +  bool merge_successors (const basic_block, const basic_block);
> +  bool backward_demand_fusion (void);
> +  bool forward_demand_fusion (void);
> +  void demand_fusion (void);
>
>    /* Phase 4.  */
>    void prune_expressions (void);
> @@ -1866,7 +1885,7 @@ pass_vsetvl::emit_local_forward_vsetvls (const
> bb_info *bb)
>  }
>
>  /* Merge all successors of Father except child node.  */
> -void
> +bool
>  pass_vsetvl::merge_successors (const basic_block father,
>                                const basic_block child)
>  {
> @@ -1877,7 +1896,8 @@ pass_vsetvl::merge_successors (const basic_block
> father,
>               || father_info.local_dem.empty_p ());
>    gcc_assert (father_info.reaching_out.dirty_p ()
>               || father_info.reaching_out.empty_p ());
> -
> +
> +  bool changed_p = false;
>    FOR_EACH_EDGE (e, ei, father->succs)
>      {
>        const basic_block succ = e->dest;
> @@ -1907,12 +1927,15 @@ pass_vsetvl::merge_successors (const basic_block
> father,
>
>        father_info.local_dem = new_info;
>        father_info.reaching_out = new_info;
> +      changed_p = true;
>      }
> +
> +  return changed_p;
>  }
>
>  /* Compute global backward demanded info.  */
> -void
> -pass_vsetvl::compute_global_backward_infos (void)
> +bool
> +pass_vsetvl::backward_demand_fusion (void)
>  {
>    /* We compute global infos by backward propagation.
>       We want to have better performance in these following cases:
> @@ -1939,6 +1962,7 @@ pass_vsetvl::compute_global_backward_infos (void)
>            We backward propagate the first VSETVL into e32,mf2 so that we
>            could be able to eliminate the second VSETVL in LCM.  */
>
> +  bool changed_p = false;
>    for (const bb_info *bb : crtl->ssa->reverse_bbs ())
>      {
>        basic_block cfg_bb = bb->cfg_bb ();
> @@ -1982,9 +2006,10 @@ pass_vsetvl::compute_global_backward_infos (void)
>                   block_info.reaching_out.set_dirty ();
>                   block_info.reaching_out.set_dirty_pat (new_pat);
>                   block_info.local_dem = block_info.reaching_out;
> +                 changed_p = true;
>                 }
>
> -             merge_successors (e->src, cfg_bb);
> +             changed_p |= merge_successors (e->src, cfg_bb);
>             }
>           else if (block_info.reaching_out.dirty_p ())
>             {
> @@ -2011,6 +2036,7 @@ pass_vsetvl::compute_global_backward_infos (void)
>               new_info.set_dirty_pat (new_pat);
>               block_info.local_dem = new_info;
>               block_info.reaching_out = new_info;
> +             changed_p = true;
>             }
>           else
>             {
> @@ -2031,9 +2057,99 @@ pass_vsetvl::compute_global_backward_infos (void)
>               if (block_info.local_dem == block_info.reaching_out)
>                 block_info.local_dem = new_info;
>               block_info.reaching_out = new_info;
> +             changed_p = true;
> +           }
> +       }
> +    }
> +  return changed_p;
> +}
> +
> +/* Compute global forward demanded info.  */
> +bool
> +pass_vsetvl::forward_demand_fusion (void)
> +{
> +  /* Enhance the global information propagation especially
> +     backward propagation miss the propagation.
> +     Consider such case:
> +
> +                       bb0
> +                       (TU)
> +                      /   \
> +                    bb1   bb2
> +                    (TU)  (ANY)
> +  existing edge -----> \    / (TU) <----- LCM create this edge.
> +                       bb3
> +                       (TU)
> +
> +     Base on the situation, LCM fails to eliminate the VSETVL instruction
> and
> +     insert an edge from bb2 to bb3 since we can't backward propagate bb3
> into
> +     bb2. To avoid this confusing LCM result and non-optimal codegen, we
> should
> +     forward propagate information from bb0 to bb2 which is friendly to
> LCM.  */
> +  bool changed_p = false;
> +  for (const bb_info *bb : crtl->ssa->bbs ())
> +    {
> +      basic_block cfg_bb = bb->cfg_bb ();
> +      const auto &prop
> +       = m_vector_manager->vector_block_infos[cfg_bb->index].reaching_out;
> +
> +      /* If there is nothing to propagate, just skip it.  */
> +      if (!prop.valid_or_dirty_p ())
> +       continue;
> +
> +      edge e;
> +      edge_iterator ei;
> +      /* Forward propagate to each successor.  */
> +      FOR_EACH_EDGE (e, ei, cfg_bb->succs)
> +       {
> +         auto &local_dem
> +           =
> m_vector_manager->vector_block_infos[e->dest->index].local_dem;
> +         auto &reaching_out
> +           =
> m_vector_manager->vector_block_infos[e->dest->index].reaching_out;
> +
> +         /* It's quite obvious, we don't need to propagate itself.  */
> +         if (e->dest->index == cfg_bb->index)
> +           continue;
> +
> +         /* If there is nothing to propagate, just skip it.  */
> +         if (!local_dem.valid_or_dirty_p ())
> +           continue;
> +
> +         if (prop > local_dem)
> +           {
> +             if (local_dem.dirty_p ())
> +               {
> +                 gcc_assert (local_dem == reaching_out);
> +                 rtx dirty_pat
> +                   = gen_vsetvl_pat (prop.get_insn ()->rtl (), prop);
> +                 local_dem = prop;
> +                 local_dem.set_dirty ();
> +                 local_dem.set_dirty_pat (dirty_pat);
> +                 reaching_out = local_dem;
> +               }
> +             else
> +               {
> +                 if (reaching_out == local_dem)
> +                   reaching_out.set_demand_info (prop);
> +                 local_dem.set_demand_info (prop);
> +                 change_vsetvl_insn (local_dem.get_insn (), prop);
> +               }
> +             changed_p = true;
>             }
>         }
>      }
> +  return changed_p;
> +}
> +
> +void
> +pass_vsetvl::demand_fusion (void)
> +{
> +  bool changed_p = true;
> +  while (changed_p)
> +    {
> +      changed_p = false;
> +      changed_p |= backward_demand_fusion ();
> +      changed_p |= forward_demand_fusion ();
> +    }
>
>    if (dump_file)
>      {
> @@ -2519,7 +2635,7 @@ pass_vsetvl::lazy_vsetvl (void)
>    /* Phase 3 - Propagate demanded info across blocks.  */
>    if (dump_file)
>      fprintf (dump_file, "\nPhase 3: Demands propagation across blocks\n");
> -  compute_global_backward_infos ();
> +  demand_fusion ();
>    if (dump_file)
>      m_vector_manager->dump (dump_file);
>
> diff --git a/gcc/config/riscv/riscv-vsetvl.h
> b/gcc/config/riscv/riscv-vsetvl.h
> index c8218a6ff00..33481a87163 100644
> --- a/gcc/config/riscv/riscv-vsetvl.h
> +++ b/gcc/config/riscv/riscv-vsetvl.h
> @@ -273,6 +273,7 @@ public:
>    void set_dirty () { m_state = DIRTY; }
>    void set_dirty_pat (rtx pat) { m_dirty_pat = pat; }
>    void set_insn (rtl_ssa::insn_info *insn) { m_insn = insn; }
> +  void set_demand_info (const vector_insn_info &);
>
>    bool demand_p (enum demand_type type) const { return m_demands[type]; }
>    void demand (enum demand_type type) { m_demands[type] = true; }
> --
> 2.36.3
>
>
  

Patch

diff --git a/gcc/config/riscv/riscv-vsetvl.cc b/gcc/config/riscv/riscv-vsetvl.cc
index 52f0195980a..d42cfa91d63 100644
--- a/gcc/config/riscv/riscv-vsetvl.cc
+++ b/gcc/config/riscv/riscv-vsetvl.cc
@@ -43,7 +43,7 @@  along with GCC; see the file COPYING3.  If not see
     -  Phase 2 - Emit vsetvl instructions within each basic block according to
        demand, compute and save ANTLOC && AVLOC of each block.
 
-    -  Phase 3 - Backward demanded info propagation and fusion across blocks.
+    -  Phase 3 - Backward && forward demanded info propagation and fusion across blocks.
 
     -  Phase 4 - Lazy code motion including: compute local properties,
        pre_edge_lcm and vsetvl insertion && delete edges for LCM results.
@@ -434,8 +434,12 @@  can_backward_propagate_p (const function_info *ssa, const basic_block cfg_bb,
 	set_info *ultimate_def = look_through_degenerate_phi (set);
 	const basic_block ultimate_bb = ultimate_def->bb ()->cfg_bb ();
 	FOR_BB_BETWEEN (iter_bb, ultimate_bb, def->bb ()->cfg_bb (), next_bb)
-	  if (iter_bb->index == cfg_bb->index)
-	    return true;
+	  {
+	    if (!iter_bb)
+	      break;
+	    if (iter_bb->index == cfg_bb->index)
+	      return true;
+	  }
 
 	return false;
       };
@@ -1172,6 +1176,19 @@  vector_insn_info::parse_insn (insn_info *insn)
     m_demands[DEMAND_MASK_POLICY] = true;
 }
 
+void
+vector_insn_info::set_demand_info (const vector_insn_info &other)
+{
+  set_sew (other.get_sew ());
+  set_vlmul (other.get_vlmul ());
+  set_ratio (other.get_ratio ());
+  set_ta (other.get_ta ());
+  set_ma (other.get_ma ());
+  set_avl_info (other.get_avl_info ());
+  for (size_t i = 0; i < NUM_DEMAND; i++)
+    m_demands[i] = other.demand_p ((enum demand_type) i);
+}
+
 void
 vector_insn_info::demand_vl_vtype ()
 {
@@ -1691,8 +1708,10 @@  private:
   void emit_local_forward_vsetvls (const bb_info *);
 
   /* Phase 3.  */
-  void merge_successors (const basic_block, const basic_block);
-  void compute_global_backward_infos (void);
+  bool merge_successors (const basic_block, const basic_block);
+  bool backward_demand_fusion (void);
+  bool forward_demand_fusion (void);
+  void demand_fusion (void);
 
   /* Phase 4.  */
   void prune_expressions (void);
@@ -1866,7 +1885,7 @@  pass_vsetvl::emit_local_forward_vsetvls (const bb_info *bb)
 }
 
 /* Merge all successors of Father except child node.  */
-void
+bool
 pass_vsetvl::merge_successors (const basic_block father,
 			       const basic_block child)
 {
@@ -1877,7 +1896,8 @@  pass_vsetvl::merge_successors (const basic_block father,
 	      || father_info.local_dem.empty_p ());
   gcc_assert (father_info.reaching_out.dirty_p ()
 	      || father_info.reaching_out.empty_p ());
-
+  
+  bool changed_p = false;
   FOR_EACH_EDGE (e, ei, father->succs)
     {
       const basic_block succ = e->dest;
@@ -1907,12 +1927,15 @@  pass_vsetvl::merge_successors (const basic_block father,
 
       father_info.local_dem = new_info;
       father_info.reaching_out = new_info;
+      changed_p = true;
     }
+
+  return changed_p;
 }
 
 /* Compute global backward demanded info.  */
-void
-pass_vsetvl::compute_global_backward_infos (void)
+bool
+pass_vsetvl::backward_demand_fusion (void)
 {
   /* We compute global infos by backward propagation.
      We want to have better performance in these following cases:
@@ -1939,6 +1962,7 @@  pass_vsetvl::compute_global_backward_infos (void)
 	   We backward propagate the first VSETVL into e32,mf2 so that we
 	   could be able to eliminate the second VSETVL in LCM.  */
 
+  bool changed_p = false;
   for (const bb_info *bb : crtl->ssa->reverse_bbs ())
     {
       basic_block cfg_bb = bb->cfg_bb ();
@@ -1982,9 +2006,10 @@  pass_vsetvl::compute_global_backward_infos (void)
 		  block_info.reaching_out.set_dirty ();
 		  block_info.reaching_out.set_dirty_pat (new_pat);
 		  block_info.local_dem = block_info.reaching_out;
+		  changed_p = true;
 		}
 
-	      merge_successors (e->src, cfg_bb);
+	      changed_p |= merge_successors (e->src, cfg_bb);
 	    }
 	  else if (block_info.reaching_out.dirty_p ())
 	    {
@@ -2011,6 +2036,7 @@  pass_vsetvl::compute_global_backward_infos (void)
 	      new_info.set_dirty_pat (new_pat);
 	      block_info.local_dem = new_info;
 	      block_info.reaching_out = new_info;
+	      changed_p = true;
 	    }
 	  else
 	    {
@@ -2031,9 +2057,99 @@  pass_vsetvl::compute_global_backward_infos (void)
 	      if (block_info.local_dem == block_info.reaching_out)
 		block_info.local_dem = new_info;
 	      block_info.reaching_out = new_info;
+	      changed_p = true;
+	    }
+	}
+    }
+  return changed_p;
+}
+
+/* Compute global forward demanded info.  */
+bool
+pass_vsetvl::forward_demand_fusion (void)
+{
+  /* Enhance global information propagation, especially for the cases
+     that backward propagation misses.
+     Consider such a case:
+
+			bb0
+			(TU)
+		       /   \
+		     bb1   bb2
+		     (TU)  (ANY)
+  existing edge -----> \    / (TU) <----- LCM create this edge.
+			bb3
+			(TU)
+
+     In this situation, LCM fails to eliminate the VSETVL instruction and
+     inserts an edge from bb2 to bb3, since we can't backward propagate bb3
+     into bb2.  To avoid this confusing LCM result and non-optimal codegen, we
+     forward propagate information from bb0 to bb2, which is friendly to LCM.  */
+  bool changed_p = false;
+  for (const bb_info *bb : crtl->ssa->bbs ())
+    {
+      basic_block cfg_bb = bb->cfg_bb ();
+      const auto &prop
+	= m_vector_manager->vector_block_infos[cfg_bb->index].reaching_out;
+
+      /* If there is nothing to propagate, just skip it.  */
+      if (!prop.valid_or_dirty_p ())
+	continue;
+
+      edge e;
+      edge_iterator ei;
+      /* Forward propagate to each successor.  */
+      FOR_EACH_EDGE (e, ei, cfg_bb->succs)
+	{
+	  auto &local_dem
+	    = m_vector_manager->vector_block_infos[e->dest->index].local_dem;
+	  auto &reaching_out
+	    = m_vector_manager->vector_block_infos[e->dest->index].reaching_out;
+
+	  /* Obviously, we don't need to propagate to the block itself.  */
+	  if (e->dest->index == cfg_bb->index)
+	    continue;
+
+	  /* If the successor has no demand to fuse with, just skip it.  */
+	  if (!local_dem.valid_or_dirty_p ())
+	    continue;
+
+	  if (prop > local_dem)
+	    {
+	      if (local_dem.dirty_p ())
+		{
+		  gcc_assert (local_dem == reaching_out);
+		  rtx dirty_pat
+		    = gen_vsetvl_pat (prop.get_insn ()->rtl (), prop);
+		  local_dem = prop;
+		  local_dem.set_dirty ();
+		  local_dem.set_dirty_pat (dirty_pat);
+		  reaching_out = local_dem;
+		}
+	      else
+		{
+		  if (reaching_out == local_dem)
+		    reaching_out.set_demand_info (prop);
+		  local_dem.set_demand_info (prop);
+		  change_vsetvl_insn (local_dem.get_insn (), prop);
+		}
+	      changed_p = true;
 	    }
 	}
     }
+  return changed_p;
+}
+
+void
+pass_vsetvl::demand_fusion (void)
+{
+  bool changed_p = true;
+  while (changed_p)
+    {
+      changed_p = false;
+      changed_p |= backward_demand_fusion ();
+      changed_p |= forward_demand_fusion ();
+    }
 
   if (dump_file)
     {
@@ -2519,7 +2635,7 @@  pass_vsetvl::lazy_vsetvl (void)
   /* Phase 3 - Propagate demanded info across blocks.  */
   if (dump_file)
     fprintf (dump_file, "\nPhase 3: Demands propagation across blocks\n");
-  compute_global_backward_infos ();
+  demand_fusion ();
   if (dump_file)
     m_vector_manager->dump (dump_file);
 
diff --git a/gcc/config/riscv/riscv-vsetvl.h b/gcc/config/riscv/riscv-vsetvl.h
index c8218a6ff00..33481a87163 100644
--- a/gcc/config/riscv/riscv-vsetvl.h
+++ b/gcc/config/riscv/riscv-vsetvl.h
@@ -273,6 +273,7 @@  public:
   void set_dirty () { m_state = DIRTY; }
   void set_dirty_pat (rtx pat) { m_dirty_pat = pat; }
   void set_insn (rtl_ssa::insn_info *insn) { m_insn = insn; }
+  void set_demand_info (const vector_insn_info &);
 
   bool demand_p (enum demand_type type) const { return m_demands[type]; }
   void demand (enum demand_type type) { m_demands[type] = true; }