[2/5,_Hashtable] New method to check current bucket

Message ID 82bd8c6e-760f-a6c8-2e4a-fad412a0ce2c@gmail.com
State New
Headers
Series [1/5,_Hashtable] Make more use of user provided hint |

Commit Message

François Dumont June 20, 2022, 4:57 p.m. UTC
  libstdc++: [_Hashtable] Use next bucket node and equal_to to check if 
same bucket

To find out if we are still in the same bucket we can first check that 
current node
is not the next bucket's before-begin and then that hash code are equals 
when cached.
If not we can also use the equal_to functor in a multi-container 
context. As a last
resort, compute node bucket index.

libstdc++-v3/ChangeLog:

     * include/bits/hashtable_policy.h 
(_Hashtable_base<>::_S_hash_code_equals): New.
     * include/bits/hashtable.h (_Hashtable<>::_M_is_in_bucket): New, 
use latter.
     (_Hashtable<>::_M_find_before_node): Use latter.
     (_Hashtable<>::_M_find_before_node_tr): Likewise.

Tested under Linux x86_64.

François
  

Patch

diff --git a/libstdc++-v3/include/bits/hashtable.h b/libstdc++-v3/include/bits/hashtable.h
index 8318da168e3..e53cbaf0644 100644
--- a/libstdc++-v3/include/bits/hashtable.h
+++ b/libstdc++-v3/include/bits/hashtable.h
@@ -801,6 +801,33 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
       __node_base_ptr
       _M_find_before_node(const key_type&);
 
+      bool
+      _M_is_in_bucket(size_type __bkt, __node_ptr, __node_ptr __n,
+		      true_type /* __uks */) const
+      { return _M_bucket_index(*__n) == __bkt; }
+
+      bool
+      _M_is_in_bucket(size_type __bkt, __node_ptr __prev_n, __node_ptr __n,
+		      false_type /* __uks */) const
+      {
+	return this->_M_key_equals(_ExtractKey{}(__prev_n->_M_v()), *__n)
+	  || _M_bucket_index(*__n) == __bkt;
+      }
+
+      bool
+      _M_is_nxt_in_bucket(size_type __bkt, __node_ptr __prev_n,
+			  __node_base_ptr __nxt_bkt_n) const
+      {
+	if (__prev_n == __nxt_bkt_n)
+	  return false;
+
+	__node_ptr __n = __prev_n->_M_next();
+	if (this->_S_hash_code_equals(*__prev_n, *__n))
+	  return true;
+
+	return _M_is_in_bucket(__bkt, __prev_n, __n, __unique_keys{});
+      }
+
       // Find and insert helper functions and types
       // Find the node before the one matching the criteria.
       __node_base_ptr
@@ -1999,13 +2026,15 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
       if (!__prev_p)
 	return nullptr;
 
+      __node_base_ptr __nxt_bkt_n
+	= __bkt < _M_bucket_count - 1 ? _M_buckets[__bkt + 1] : nullptr;
       for (__node_ptr __p = static_cast<__node_ptr>(__prev_p->_M_nxt);;
 	   __p = __p->_M_next())
 	{
 	  if (this->_M_equals(__k, __code, *__p))
 	    return __prev_p;
 
-	  if (!__p->_M_nxt || _M_bucket_index(*__p->_M_next()) != __bkt)
+	  if (!__p->_M_nxt || !_M_is_nxt_in_bucket(__bkt, __p, __nxt_bkt_n))
 	    break;
 	  __prev_p = __p;
 	}
@@ -2029,13 +2058,15 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	if (!__prev_p)
 	  return nullptr;
 
+	__node_base_ptr __nxt_bkt_n
+	  = __bkt < _M_bucket_count - 1 ? _M_buckets[__bkt + 1] : nullptr;
 	for (__node_ptr __p = static_cast<__node_ptr>(__prev_p->_M_nxt);;
 	     __p = __p->_M_next())
 	  {
 	    if (this->_M_equals_tr(__k, __code, *__p))
 	      return __prev_p;
 
-	    if (!__p->_M_nxt || _M_bucket_index(*__p->_M_next()) != __bkt)
+	    if (!__p->_M_nxt || !_M_is_nxt_in_bucket(__bkt, __p, __nxt_bkt_n))
 	      break;
 	    __prev_p = __p;
 	  }
diff --git a/libstdc++-v3/include/bits/hashtable_policy.h b/libstdc++-v3/include/bits/hashtable_policy.h
index 83a9ff2bb3d..e848ba1d3f7 100644
--- a/libstdc++-v3/include/bits/hashtable_policy.h
+++ b/libstdc++-v3/include/bits/hashtable_policy.h
@@ -1721,6 +1721,16 @@  namespace __detail
       : __hash_code_base(__hash), _EqualEBO(__eq)
       { }
 
+      static bool
+      _S_hash_code_equals(const _Hash_node_code_cache<false>&,
+			  const _Hash_node_code_cache<false>&)
+      { return false; }
+
+      static bool
+      _S_hash_code_equals(const _Hash_node_code_cache<true>& __lhn,
+			  const _Hash_node_code_cache<true>& __rhn)
+      { return __lhn._M_hash_code == __rhn._M_hash_code; }
+
       bool
       _M_key_equals(const _Key& __k,
 		    const _Hash_node_value<_Value,