[04/10] libstdc++: Various fixes for atomic wait/notify code

Message ID 20250110212810.832494-5-jwakely@redhat.com
State New
Series C++20 atomic wait/notify ABI stabilization

Checks

Context Check Description
linaro-tcwg-bot/tcwg_gcc_build--master-aarch64 success Build passed
linaro-tcwg-bot/tcwg_gcc_build--master-arm success Build passed
linaro-tcwg-bot/tcwg_gcc_check--master-aarch64 fail Test failed

Commit Message

Jonathan Wakely Jan. 10, 2025, 9:23 p.m. UTC
Remove the unnecessary __wait_args_base base class, which seems to serve
no purpose. Everywhere it's used we actually want __wait_args instead.

The code using the __wait_flags bitmask type is broken, because the
__spin_only constant includes the __do_spin element. This means that
testing (__args & __wait_flags::__spin_only) will be inadvertently true
when only __do_spin is set. This causes the __wait_until_impl function
to never actually wait on the futex (or condition variable), turning all
uses of that function into expensive busy spins. Change __spin_only to
be a single bit (i.e. a bitmask element) and adjust the places where
that bit is set so that they also use the __do_spin element.
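
To make the flag interaction concrete, here is a minimal standalone
sketch of the problem; the enumerator names and values mirror
__wait_flags, but this is not the libstdc++ code:

  #include <cassert>

  enum class wait_flags : unsigned {
    do_spin       = 4,
    spin_only_old = 8 | 4,  // old definition: includes do_spin
    spin_only_new = 8       // new definition: a single bit
  };

  constexpr unsigned to_u(wait_flags f) { return static_cast<unsigned>(f); }

  int main()
  {
    // Flags as set for an ordinary wait: do_spin only, spin-only not requested.
    unsigned args = to_u(wait_flags::do_spin);

    // With the old value the "spin only, never block" test is true anyway,
    // so the wait path would busy-spin instead of blocking on the futex.
    assert((args & to_u(wait_flags::spin_only_old)) != 0);

    // With the single-bit value the test is false, as intended.
    assert((args & to_u(wait_flags::spin_only_new)) == 0);
  }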

Update the __args._M_old value when looping in __atomic_wait_address, so
that the next wait doesn't fail spuriously.
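
A rough model of that loop, using std::atomic<int>::wait as a stand-in
for the internal futex path (the names here are illustrative, not the
actual implementation):

  #include <atomic>
  #include <thread>

  template<typename Pred, typename ValFn>
  void wait_address_model(std::atomic<int>& a, Pred pred, ValFn vfn)
  {
    int val = vfn();
    while (!pred(val))
      {
        // a.wait(old) only blocks while a.load() == old.  Refreshing `old`
        // from the latest load on every iteration is what the patch does
        // for __args._M_old; with a stale value the call would return
        // immediately instead of blocking.
        a.wait(val);
        val = vfn();
      }
  }

  int main()
  {
    std::atomic<int> a{0};
    std::thread t([&] { a.store(2); a.notify_one(); });
    wait_address_model(a, [](int v) { return v == 2; },
                       [&] { return a.load(std::memory_order_acquire); });
    t.join();
  }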

With the new __atomic_wait_address logic, the value function needs to
return the correct type, not just a bool. Without that change, the
boolean value returned by the value function is used as the value
passed to the futex wait, but that means we're comparing (_M_a == 0) to
_M_a and so can block on the futex when we shouldn't, and then never
wake up.
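
A simplified illustration of the two shapes (not the exact library
code; the real latch fix below uses __platform_wait_t):

  #include <atomic>

  int main()
  {
    std::atomic<int> counter{1};   // stand-in for latch::_M_a

    // Wrong shape: the value function already applies the predicate, so the
    // waiting layer ends up comparing the bool result of (counter == 0)
    // against the counter's own value when deciding whether to block.
    auto bad_vfn = [&] { return counter.load(std::memory_order_acquire) == 0; };

    // Right shape: return the value itself and let a separate predicate
    // interpret it, so the futex compares like with like.
    auto vfn  = [&] { return counter.load(std::memory_order_acquire); };
    auto pred = [](int v) { return v == 0; };

    (void)bad_vfn; (void)vfn; (void)pred;
  }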

libstdc++-v3/ChangeLog:

	* include/bits/atomic_timed_wait.h (__cond_wait_impl): Add
	missing inline keyword.
	(__spin_until_impl): Change parameter from __wait_args_base*
	to __wait_args*. Replace make_pair with list-initialization.
	Initialize variable for return value.
	(__wait_until_impl): Likewise. Remove some preprocessor
	conditional logic. Use _M_tracker() for contention tracking.
	Avoid unnecessary const_cast.
	(__wait_until): Replace make_pair with list-initialization.
	(__wait_for):  Change parameter from __wait_args_base* to
	__wait_args*. Add __do_spin flag to args.
	* include/bits/atomic_wait.h (__wait_flags): Do not set the
	__do_spin flag in the __spin_only enumerator. Comment out the
	unused __abi_version_mask enumerator. Define operator| and
	operator|= overloads.
	(__wait_args_base): Remove.
	(__wait_args): Move __wait_args_base members here and remove the
	base class. Constrain constructor so it cannot be selected in
	place of the copy constructor.
	(__wait_args::_M_tracker): Helper function to create an RAII
	object to track contention.
	(__spin_impl): Change parameter from __wait_args_base* to
	__wait_args*. Replace make_pair call with list-initialization.
	(__wait_impl): Likewise.  Remove some preprocessor conditional
	logic.  Always store old value in __args._M_old. Avoid
	unnecessary const_cast. Use _M_tracker.
	(__notify_impl): Change parameter to __wait_args*. Remove some
	preprocessor conditional logic.
	(__atomic_wait_address): Add comment. Update __args._M_old on
	each iteration.
	(__atomic_wait_address_v): Add comment.
	* include/std/latch (latch::wait): Adjust predicates for new
	logic.
	* testsuite/29_atomics/atomic_integral/wait_notify.cc: Improve
	test.
---
 libstdc++-v3/include/bits/atomic_timed_wait.h |  73 ++----
 libstdc++-v3/include/bits/atomic_wait.h       | 221 +++++++++---------
 libstdc++-v3/include/std/latch                |   8 +-
 .../29_atomics/atomic_integral/wait_notify.cc |   4 +
 4 files changed, 147 insertions(+), 159 deletions(-)
  

Patch

diff --git a/libstdc++-v3/include/bits/atomic_timed_wait.h b/libstdc++-v3/include/bits/atomic_timed_wait.h
index 73acea939504..7e2017f2f515 100644
--- a/libstdc++-v3/include/bits/atomic_timed_wait.h
+++ b/libstdc++-v3/include/bits/atomic_timed_wait.h
@@ -111,7 +111,7 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 
 #ifdef _GLIBCXX_HAS_GTHREADS
     // Returns true if wait ended before timeout.
-    bool
+    inline bool
     __cond_wait_until(__condvar& __cv, mutex& __mx,
 		      const __wait_clock_t::time_point& __atime)
     {
@@ -136,14 +136,14 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 
     inline __wait_result_type
     __spin_until_impl(const __platform_wait_t* __addr,
-		      const __wait_args_base* __a,
+		      const __wait_args* __a,
 		      const __wait_clock_t::time_point& __deadline)
     {
       __wait_args __args{ *__a };
       auto __t0 = __wait_clock_t::now();
       using namespace literals::chrono_literals;
 
-      __platform_wait_t __val;
+      __platform_wait_t __val{};
       auto __now = __wait_clock_t::now();
       for (; __now < __deadline; __now = __wait_clock_t::now())
 	{
@@ -157,44 +157,32 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 #endif
 	  if (__elapsed > 4us)
 	    __thread_yield();
-	  else
-	    {
-	      auto __res = __detail::__spin_impl(__addr, __a);
-	      if (__res.first)
-		return __res;
-	    }
+	  else if (auto __res = __detail::__spin_impl(__addr, __a); __res.first)
+	    return __res;
 
 	  __atomic_load(__addr, &__val, __args._M_order);
 	  if (__val != __args._M_old)
-	      return make_pair(true, __val);
+	    return { true, __val };
 	}
-      return make_pair(false, __val);
+      return { false, __val };
     }
 
     inline __wait_result_type
     __wait_until_impl(const __platform_wait_t* __addr,
-		      const __wait_args_base* __a,
+		      const __wait_args* __a,
 		      const __wait_clock_t::time_point& __atime)
     {
       __wait_args __args{ *__a };
-#ifdef _GLIBCXX_HAVE_PLATFORM_TIMED_WAIT
       __waiter_pool_impl* __pool = nullptr;
-#else
-      // if we don't have __platform_wait, we always need the side-table
-      __waiter_pool_impl* __pool = &__waiter_pool_impl::_S_impl_for(__addr);
-#endif
-
-      __platform_wait_t* __wait_addr;
+      const __platform_wait_t* __wait_addr;
       if (__args & __wait_flags::__proxy_wait)
 	{
-#ifdef _GLIBCXX_HAVE_PLATFORM_TIMED_WAIT
 	  __pool = &__waiter_pool_impl::_S_impl_for(__addr);
-#endif
 	  __wait_addr = &__pool->_M_ver;
 	  __atomic_load(__wait_addr, &__args._M_old, __args._M_order);
 	}
       else
-	__wait_addr = const_cast<__platform_wait_t*>(__addr);
+	__wait_addr = __addr;
 
       if (__args & __wait_flags::__do_spin)
 	{
@@ -205,43 +193,31 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	    return __res;
 	}
 
-      if (!(__args & __wait_flags::__track_contention))
-	{
-	  // caller does not externally track contention
-#ifdef _GLIBCXX_HAVE_PLATFORM_TIMED_WAIT
-	  __pool = (__pool == nullptr) ? &__waiter_pool_impl::_S_impl_for(__addr)
-				       : __pool;
-#endif
-	  __pool->_M_enter_wait();
-	}
+      auto __tracker = __args._M_tracker(__pool, __addr);
 
-      __wait_result_type __res;
 #ifdef _GLIBCXX_HAVE_PLATFORM_TIMED_WAIT
       if (__platform_wait_until(__wait_addr, __args._M_old, __atime))
-	__res = make_pair(true, __args._M_old);
+	return { true, __args._M_old };
       else
-	__res = make_pair(false, __args._M_old);
+	return { false, __args._M_old };
 #else
-      __platform_wait_t __val;
+      __platform_wait_t __val{};
       __atomic_load(__wait_addr, &__val, __args._M_order);
       if (__val == __args._M_old)
 	{
+	  if (!__pool)
+	    __pool = &__waiter_pool_impl::_S_impl_for(__addr);
 	  lock_guard<mutex> __l{ __pool->_M_mtx };
 	  __atomic_load(__wait_addr, &__val, __args._M_order);
-	  if (__val == __args._M_old &&
-	      __cond_wait_until(__pool->_M_cv, __pool->_M_mtx, __atime))
-	    __res = make_pair(true, __val);
+	  if (__val == __args._M_old
+		&& __cond_wait_until(__pool->_M_cv, __pool->_M_mtx, __atime))
+	    return { true, __val };
 	}
-      else
-	__res = make_pair(false, __val);
+      return { false, __val };
 #endif
-
-      if (!(__args & __wait_flags::__track_contention))
-	// caller does not externally track contention
-	__pool->_M_leave_wait();
-      return __res;
     }
 
+    // Returns {true, val} if wait ended before a timeout.
     template<typename _Clock, typename _Dur>
       __wait_result_type
       __wait_until(const __platform_wait_t* __addr, const __wait_args* __args,
@@ -259,22 +235,23 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 		// we need to check against the caller-supplied clock
 		// to tell whether we should return a timeout.
 		if (_Clock::now() < __atime)
-		  return make_pair(true, __res.second);
+		  __res.first = true;
 	      }
 	    return __res;
 	  }
       }
 
+    // Returns {true, val} if wait ended before a timeout.
     template<typename _Rep, typename _Period>
       __wait_result_type
-      __wait_for(const __platform_wait_t* __addr, const __wait_args_base* __a,
+      __wait_for(const __platform_wait_t* __addr, const __wait_args* __a,
 		 const chrono::duration<_Rep, _Period>& __rtime) noexcept
       {
 	__wait_args __args{ *__a };
 	if (!__rtime.count())
 	  {
 	    // no rtime supplied, just spin a bit
-	    __args |= __wait_flags::__spin_only;
+	    __args |= __wait_flags::__do_spin | __wait_flags::__spin_only;
 	    return __detail::__wait_impl(__addr, &__args);
 	  }
 
diff --git a/libstdc++-v3/include/bits/atomic_wait.h b/libstdc++-v3/include/bits/atomic_wait.h
index ebab4b099e66..29b83cad6e6c 100644
--- a/libstdc++-v3/include/bits/atomic_wait.h
+++ b/libstdc++-v3/include/bits/atomic_wait.h
@@ -177,6 +177,7 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 #ifndef _GLIBCXX_HAVE_PLATFORM_WAIT
       __condvar _M_cv;
 #endif
+
       __waiter_pool_impl() = default;
 
       void
@@ -211,88 +212,104 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
        __proxy_wait = 1,
        __track_contention = 2,
        __do_spin = 4,
-       __spin_only = 8 | __do_spin, // implies __do_spin
-       __abi_version_mask = 0xffff0000,
+       __spin_only = 8, // Ignored unless __do_spin is also set.
+       // __abi_version_mask = 0xffff0000,
     };
 
-    struct __wait_args_base
+    [[__gnu__::__always_inline__]]
+    constexpr __wait_flags
+    operator|(__wait_flags __l, __wait_flags __r) noexcept
+    {
+      using _Ut = underlying_type_t<__wait_flags>;
+      return static_cast<__wait_flags>(static_cast<_Ut>(__l)
+					 | static_cast<_Ut>(__r));
+    }
+
+    [[__gnu__::__always_inline__]]
+    constexpr __wait_flags&
+    operator|=(__wait_flags& __l, __wait_flags __r) noexcept
+    { return __l = __l | __r; }
+
+    struct __wait_args
     {
       __wait_flags _M_flags;
       int _M_order = __ATOMIC_ACQUIRE;
       __platform_wait_t _M_old = 0;
-    };
 
-    struct __wait_args : __wait_args_base
-    {
-      template<typename _Tp>
-	explicit __wait_args(const _Tp* __addr,
-			     bool __bare_wait = false) noexcept
-	    : __wait_args_base{ _S_flags_for(__addr, __bare_wait) }
+      template<typename _Tp> requires (!is_same_v<_Tp, __wait_args>)
+	explicit
+	__wait_args(const _Tp* __addr, bool __bare_wait = false) noexcept
+	: _M_flags{ _S_flags_for(__addr, __bare_wait) }
 	{ }
 
       __wait_args(const __platform_wait_t* __addr, __platform_wait_t __old,
 		  int __order, bool __bare_wait = false) noexcept
-	  : __wait_args_base{ _S_flags_for(__addr, __bare_wait), __order, __old }
-	{ }
-
-      explicit __wait_args(const __wait_args_base& __base)
-	  : __wait_args_base{ __base }
-	{ }
+      : _M_flags{ _S_flags_for(__addr, __bare_wait)},
+	_M_order{__order}, _M_old{__old}
+      { }
 
       __wait_args(const __wait_args&) noexcept = default;
-      __wait_args&
-      operator=(const __wait_args&) noexcept = default;
+      __wait_args& operator=(const __wait_args&) noexcept = default;
 
+      // Test whether _M_flags & __flag is non-zero.
       bool
-      operator&(__wait_flags __flag) const noexcept
+      operator&(__wait_flags __flags) const noexcept
       {
-	 using __t = underlying_type_t<__wait_flags>;
-	 return static_cast<__t>(_M_flags)
-	     & static_cast<__t>(__flag);
-      }
-
-      __wait_args
-      operator|(__wait_flags __flag) const noexcept
-      {
-	using __t = underlying_type_t<__wait_flags>;
-	__wait_args __res{ *this };
-	const auto __flags = static_cast<__t>(__res._M_flags)
-			     | static_cast<__t>(__flag);
-	__res._M_flags = __wait_flags{ __flags };
-	return __res;
+	 using _Ut = underlying_type_t<__wait_flags>;
+	 return static_cast<_Ut>(_M_flags) & static_cast<_Ut>(__flags);
       }
 
+      // Set __flags in _M_flags.
       __wait_args&
-      operator|=(__wait_flags __flag) noexcept
+      operator|=(__wait_flags __flags) noexcept
       {
-	using __t = underlying_type_t<__wait_flags>;
-	const auto __flags = static_cast<__t>(_M_flags)
-			     | static_cast<__t>(__flag);
-	_M_flags = __wait_flags{ __flags };
+	_M_flags |= __flags;
 	return *this;
       }
 
-    private:
-      static int
-      constexpr _S_default_flags() noexcept
+      // Return an RAII type that calls __pool->_M_leave_wait() on destruction.
+      auto
+      _M_tracker(__waiter_pool_impl* __pool, const void* __addr) const
       {
-	using __t = underlying_type_t<__wait_flags>;
-	return static_cast<__t>(__wait_flags::__abi_version)
-		| static_cast<__t>(__wait_flags::__do_spin);
+	struct _Guard
+	{
+	  explicit _Guard(__waiter_pool_impl* __p) : _M_pool(__p) { }
+	  _Guard(const _Guard&) = delete;
+	  ~_Guard() { if (_M_pool) _M_pool->_M_leave_wait(); }
+	  __waiter_pool_impl* _M_pool;
+	};
+
+	if (*this & __wait_flags::__track_contention)
+	  {
+	    // Caller does not externally track contention,
+	    // so we want to increment+decrement __pool->_M_waiters
+
+	    // First make sure we have a pool for the address.
+	    if (!__pool)
+	      __pool = &__waiter_pool_impl::_S_impl_for(__addr);
+	    // Increment number of waiters:
+	    __pool->_M_enter_wait();
+	    // Returned _Guard will decrement it again on destruction.
+	    return _Guard{__pool};
+	  }
+	return _Guard{nullptr}; // For bare waits caller tracks waiters.
       }
 
+    private:
       template<typename _Tp>
-	static __wait_flags
-	constexpr _S_flags_for(const _Tp*, bool __bare_wait) noexcept
+	static constexpr __wait_flags
+	_S_flags_for(const _Tp*, bool __bare_wait) noexcept
 	{
-	  auto __res = _S_default_flags();
+	  using enum __wait_flags;
+	  __wait_flags __res = __abi_version | __do_spin;
 	  if (!__bare_wait)
-	    __res |= static_cast<int>(__wait_flags::__track_contention);
+	    __res |= __track_contention;
 	  if constexpr (!__platform_wait_uses_type<_Tp>)
-	    __res |= static_cast<int>(__wait_flags::__proxy_wait);
-	  return static_cast<__wait_flags>(__res);
+	    __res |= __proxy_wait;
+	  return __res;
 	}
 
+      // XXX what is this for? It's never used.
       template<typename _Tp>
 	static int
 	_S_memory_order_for(const _Tp*, int __order) noexcept
@@ -304,50 +321,39 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
     };
 
     using __wait_result_type = pair<bool, __platform_wait_t>;
+
     inline __wait_result_type
-    __spin_impl(const __platform_wait_t* __addr, const __wait_args_base* __args)
+    __spin_impl(const __platform_wait_t* __addr, const __wait_args* __args)
     {
       __platform_wait_t __val;
       for (auto __i = 0; __i < __atomic_spin_count; ++__i)
 	{
 	  __atomic_load(__addr, &__val, __args->_M_order);
 	  if (__val != __args->_M_old)
-	    return make_pair(true, __val);
+	    return { true, __val };
 	  if (__i < __atomic_spin_count_relax)
 	    __detail::__thread_relax();
 	  else
 	    __detail::__thread_yield();
 	}
-      return make_pair(false, __val);
+      return { false, __val };
     }
 
     inline __wait_result_type
-    __wait_impl(const __platform_wait_t* __addr, const __wait_args_base* __a)
+    __wait_impl(const __platform_wait_t* __addr, const __wait_args* __a)
     {
       __wait_args __args{ *__a };
-#ifdef _GLIBCXX_HAVE_PLATFORM_WAIT
       __waiter_pool_impl* __pool = nullptr;
-#else
-      // if we don't have __platform_wait, we always need the side-table
-      __waiter_pool_impl* __pool = &__waiter_pool_impl::_S_impl_for(__addr);
-#endif
 
-      __platform_wait_t* __wait_addr;
-      __platform_wait_t __old;
+      const __platform_wait_t* __wait_addr;
       if (__args & __wait_flags::__proxy_wait)
 	{
-#ifdef _GLIBCXX_HAVE_PLATFORM_WAIT
 	  __pool = &__waiter_pool_impl::_S_impl_for(__addr);
-#endif
 	  __wait_addr = &__pool->_M_ver;
-	  __atomic_load(__wait_addr, &__old, __args._M_order);
+	  __atomic_load(__wait_addr, &__args._M_old, __args._M_order);
 	}
       else
-	{
-	  __wait_addr = const_cast<__platform_wait_t*>(__addr);
-	  __old = __args._M_old;
-	}
-
+	__wait_addr = __addr;
 
       if (__args & __wait_flags::__do_spin)
 	{
@@ -358,86 +364,75 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	    return __res;
 	}
 
-      if (!(__args & __wait_flags::__track_contention))
-	{
-	  // caller does not externally track contention
-#ifdef _GLIBCXX_HAVE_PLATFORM_WAIT
-	  __pool = (__pool == nullptr) ? &__waiter_pool_impl::_S_impl_for(__addr)
-				       : __pool;
-#endif
-	  __pool->_M_enter_wait();
-	}
+      auto __tracker = __args._M_tracker(__pool, __addr);
 
-      __wait_result_type __res;
 #ifdef _GLIBCXX_HAVE_PLATFORM_WAIT
       __platform_wait(__wait_addr, __args._M_old);
-      __res = make_pair(false, __args._M_old);
+      return { false, __args._M_old };
 #else
       __platform_wait_t __val;
       __atomic_load(__wait_addr, &__val, __args._M_order);
       if (__val == __args._M_old)
 	{
+	  if (!__pool)
+	    __pool = &__waiter_pool_impl::_S_impl_for(__addr);
 	  lock_guard<mutex> __l{ __pool->_M_mtx };
 	  __atomic_load(__wait_addr, &__val, __args._M_order);
 	  if (__val == __args._M_old)
 	    __pool->_M_cv.wait(__pool->_M_mtx);
 	}
-      __res = make_pair(false, __val);
+      return { false, __val };
 #endif
-
-      if (!(__args & __wait_flags::__track_contention))
-	// caller does not externally track contention
-	__pool->_M_leave_wait();
-      return __res;
     }
 
     inline void
     __notify_impl(const __platform_wait_t* __addr, [[maybe_unused]] bool __all,
-		  const __wait_args_base* __a)
+		  const __wait_args* __a)
     {
-      __wait_args __args{ __a };
-#ifdef _GLIBCXX_HAVE_PLATFORM_WAIT
+      __wait_args __args{ *__a };
       __waiter_pool_impl* __pool = nullptr;
-#else
-      // if we don't have __platform_notify, we always need the side-table
-      __waiter_pool_impl* __pool = &__waiter_pool_impl::_S_impl_for(__addr);
-#endif
 
-      if (!(__args & __wait_flags::__track_contention))
+      if (__args & __wait_flags::__track_contention)
 	{
-#ifdef _GLIBCXX_HAVE_PLATFORM_WAIT
 	  __pool = &__waiter_pool_impl::_S_impl_for(__addr);
-#endif
 	  if (!__pool->_M_waiting())
 	    return;
 	}
 
-      __platform_wait_t* __wait_addr;
+      const __platform_wait_t* __wait_addr;
       if (__args & __wait_flags::__proxy_wait)
 	{
-#ifdef _GLIBCXX_HAVE_PLATFORM_WAIT
-	   __pool = (__pool == nullptr) ? &__waiter_pool_impl::_S_impl_for(__addr)
-					: __pool;
-#endif
-	   __wait_addr = &__pool->_M_ver;
-	   __atomic_fetch_add(__wait_addr, 1, __ATOMIC_RELAXED);
-	   __all = true;
-	 }
+	  if (!__pool)
+	    __pool = &__waiter_pool_impl::_S_impl_for(__addr);
+	  // Waiting for *__addr is actually done on the proxy's _M_ver.
+	  __wait_addr = &__pool->_M_ver;
+	  __atomic_fetch_add(&__pool->_M_ver, 1, __ATOMIC_RELAXED);
+	  // Because the proxy might be shared by several waiters waiting
+	  // on different atomic variables, we need to wake them all so
+	  // they can re-evaluate their conditions to see if they should
+	  // stop waiting or should wait again.
+	  __all = true;
+	}
+      else // Use the atomic variable's own address.
+	__wait_addr = __addr;
 
 #ifdef _GLIBCXX_HAVE_PLATFORM_WAIT
       __platform_notify(__wait_addr, __all);
 #else
+      if (!__pool)
+	__pool = &__waiter_pool_impl::_S_impl_for(__addr);
       lock_guard<mutex> __l{ __pool->_M_mtx };
       __pool->_M_cv.notify_all();
 #endif
     }
   } // namespace __detail
 
-  template<typename _Tp,
-	   typename _Pred, typename _ValFn>
+  // Wait on __addr while __pred(__vfn()) is false.
+  // If __bare_wait is false, increment a counter while waiting.
+  // For callers that keep their own count of waiters, use __bare_wait=true.
+  template<typename _Tp, typename _Pred, typename _ValFn>
     void
-    __atomic_wait_address(const _Tp* __addr,
-			  _Pred&& __pred, _ValFn&& __vfn,
+    __atomic_wait_address(const _Tp* __addr, _Pred&& __pred, _ValFn&& __vfn,
 			  bool __bare_wait = false) noexcept
     {
       const auto __wait_addr =
@@ -446,6 +441,13 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
       _Tp __val = __vfn();
       while (!__pred(__val))
 	{
+	  // If the wait is not proxied, set the value that we're waiting
+	  // to change.
+	  if constexpr (__platform_wait_uses_type<_Tp>)
+	    __args._M_old = __builtin_bit_cast(__detail::__platform_wait_t,
+					       __val);
+	  // Otherwise, it's a proxy wait and the proxy's _M_ver is used.
+
 	  __detail::__wait_impl(__wait_addr, &__args);
 	  __val = __vfn();
 	}
@@ -462,6 +464,7 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
     __detail::__wait_impl(__addr, &__args);
   }
 
+  // Wait on __addr while __vfn() == __old is true.
   template<typename _Tp, typename _ValFn>
     void
     __atomic_wait_address_v(const _Tp* __addr, _Tp __old,
diff --git a/libstdc++-v3/include/std/latch b/libstdc++-v3/include/std/latch
index de0afd8989bf..c81a6631d53f 100644
--- a/libstdc++-v3/include/std/latch
+++ b/libstdc++-v3/include/std/latch
@@ -78,8 +78,12 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
     _GLIBCXX_ALWAYS_INLINE void
     wait() const noexcept
     {
-      auto const __vfn = [this] { return this->try_wait(); };
-      auto const __pred = [this](bool __b) { return __b; };
+      auto const __vfn = [this] {
+	return __atomic_impl::load(&_M_a, memory_order::acquire);
+      };
+      auto const __pred = [](__detail::__platform_wait_t __v) {
+	return __v == 0;
+      };
       std::__atomic_wait_address(&_M_a, __pred, __vfn);
     }
 
diff --git a/libstdc++-v3/testsuite/29_atomics/atomic_integral/wait_notify.cc b/libstdc++-v3/testsuite/29_atomics/atomic_integral/wait_notify.cc
index c7f8779e4fb2..c6b7f637a2b4 100644
--- a/libstdc++-v3/testsuite/29_atomics/atomic_integral/wait_notify.cc
+++ b/libstdc++-v3/testsuite/29_atomics/atomic_integral/wait_notify.cc
@@ -33,12 +33,16 @@  template<typename Tp>
     std::atomic<Tp> a{ Tp(1) };
     VERIFY( a.load() == Tp(1) );
     a.wait( Tp(0) );
+    std::atomic<bool> b{false};
     std::thread t([&]
       {
+	b.store(true, std::memory_order_relaxed);
         a.store(Tp(0));
         a.notify_one();
       });
     a.wait(Tp(1));
+    // Ensure we actually waited until a.store(0) happened:
+    VERIFY( b.load(std::memory_order_relaxed) );
     t.join();
   }