libstdc++: Use memchr to optimize std::find [PR88545]
Checks
Context |
Check |
Description |
linaro-tcwg-bot/tcwg_gcc_build--master-aarch64 |
success
|
Testing passed
|
linaro-tcwg-bot/tcwg_gcc_build--master-arm |
success
|
Testing passed
|
linaro-tcwg-bot/tcwg_gcc_check--master-arm |
fail
|
Testing failed
|
Commit Message
I plan to push this after testing finishes.
-- >8 --
This optimizes std::find to use memchr when searching for an integer in
a range of bytes.
libstdc++-v3/ChangeLog:
PR libstdc++/88545
PR libstdc++/115040
* include/bits/cpp_type_traits.h (__can_use_memchr_for_find):
New variable template.
* include/bits/ranges_util.h (__find_fn): Use memchr when
possible.
* include/bits/stl_algo.h (find): Likewise.
* testsuite/25_algorithms/find/bytes.cc: New test.
---
libstdc++-v3/include/bits/cpp_type_traits.h | 13 ++
libstdc++-v3/include/bits/ranges_util.h | 17 +++
libstdc++-v3/include/bits/stl_algo.h | 35 ++++++
.../testsuite/25_algorithms/find/bytes.cc | 112 ++++++++++++++++++
4 files changed, 177 insertions(+)
create mode 100644 libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
Comments
This patch needs a tweak to not try to use memchr during constant
evaluation, i.e. check std::is_constant_evaluated().
On Wed, 5 Jun 2024 at 16:34, Jonathan Wakely <jwakely@redhat.com> wrote:
>
> I plan to push this after testing finishes.
>
> -- >8 --
>
> This optimizes std::find to use memchr when searching for an integer in
> a range of bytes.
>
> libstdc++-v3/ChangeLog:
>
> PR libstdc++/88545
> PR libstdc++/115040
> * include/bits/cpp_type_traits.h (__can_use_memchr_for_find):
> New variable template.
> * include/bits/ranges_util.h (__find_fn): Use memchr when
> possible.
> * include/bits/stl_algo.h (find): Likewise.
> * testsuite/25_algorithms/find/bytes.cc: New test.
> ---
> libstdc++-v3/include/bits/cpp_type_traits.h | 13 ++
> libstdc++-v3/include/bits/ranges_util.h | 17 +++
> libstdc++-v3/include/bits/stl_algo.h | 35 ++++++
> .../testsuite/25_algorithms/find/bytes.cc | 112 ++++++++++++++++++
> 4 files changed, 177 insertions(+)
> create mode 100644 libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
>
> diff --git a/libstdc++-v3/include/bits/cpp_type_traits.h b/libstdc++-v3/include/bits/cpp_type_traits.h
> index 59f1a1875eb..466e6792a11 100644
> --- a/libstdc++-v3/include/bits/cpp_type_traits.h
> +++ b/libstdc++-v3/include/bits/cpp_type_traits.h
> @@ -35,6 +35,10 @@
> #pragma GCC system_header
>
> #include <bits/c++config.h>
> +#include <bits/version.h>
> +#if __glibcxx_type_trait_variable_templates
> +# include <type_traits> // is_same_v, is_integral_v
> +#endif
>
> //
> // This file provides some compile-time information about various types.
> @@ -589,6 +593,15 @@ __INT_N(__GLIBCXX_TYPE_INT_N_3)
> { static constexpr bool __value = false; };
> #endif
>
> +#if __glibcxx_type_trait_variable_templates
> + template<typename _ValT, typename _Tp>
> + constexpr bool __can_use_memchr_for_find
> + // Can only use memchr to search for narrow characters and std::byte.
> + = __is_byte<_ValT>::__value
> + // And only if the value to find is an integer (or is also std::byte).
> + && (is_same_v<_Tp, _ValT> || is_integral_v<_Tp>);
> +#endif
> +
> //
> // Move iterator type
> //
> diff --git a/libstdc++-v3/include/bits/ranges_util.h b/libstdc++-v3/include/bits/ranges_util.h
> index 9b79c3a229d..7247e89a79d 100644
> --- a/libstdc++-v3/include/bits/ranges_util.h
> +++ b/libstdc++-v3/include/bits/ranges_util.h
> @@ -34,6 +34,7 @@
> # include <bits/ranges_base.h>
> # include <bits/utility.h>
> # include <bits/invoke.h>
> +# include <bits/cpp_type_traits.h> // __can_use_memchr_for_find
>
> #ifdef __glibcxx_ranges
> namespace std _GLIBCXX_VISIBILITY(default)
> @@ -494,6 +495,22 @@ namespace ranges
> operator()(_Iter __first, _Sent __last,
> const _Tp& __value, _Proj __proj = {}) const
> {
> + if constexpr (is_same_v<_Proj, identity>)
> + if constexpr(__can_use_memchr_for_find<iter_value_t<_Iter>, _Tp>)
> + if constexpr (sized_sentinel_for<_Sent, _Iter>)
> + if constexpr (contiguous_iterator<_Iter>)
> + {
> + auto __n = __last - __first;
> + if (__n > 0)
> + {
> + const int __ival = static_cast<int>(__value);
> + const void* __p0 = std::to_address(__first);
> + if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
> + __n = (const char*)__p1 - (const char*)__p0;
> + }
> + return __first + __n;
> + }
> +
> while (__first != __last
> && !(std::__invoke(__proj, *__first) == __value))
> ++__first;
> diff --git a/libstdc++-v3/include/bits/stl_algo.h b/libstdc++-v3/include/bits/stl_algo.h
> index 1a996aa61da..eba3157a480 100644
> --- a/libstdc++-v3/include/bits/stl_algo.h
> +++ b/libstdc++-v3/include/bits/stl_algo.h
> @@ -3836,6 +3836,7 @@ _GLIBCXX_BEGIN_NAMESPACE_ALGO
> * such that @c *i == @p __val, or @p __last if no such iterator exists.
> */
> template<typename _InputIterator, typename _Tp>
> + _GLIBCXX_NODISCARD
> _GLIBCXX20_CONSTEXPR
> inline _InputIterator
> find(_InputIterator __first, _InputIterator __last,
> @@ -3846,6 +3847,40 @@ _GLIBCXX_BEGIN_NAMESPACE_ALGO
> __glibcxx_function_requires(_EqualOpConcept<
> typename iterator_traits<_InputIterator>::value_type, _Tp>)
> __glibcxx_requires_valid_range(__first, __last);
> +
> +#if __cpp_if_constexpr && __glibcxx_type_trait_variable_templates
> + using _ValT = typename iterator_traits<_InputIterator>::value_type;
> + if constexpr (__can_use_memchr_for_find<_ValT, _Tp>)
> + {
> + // If converting the value to the 1-byte value_type alters its value,
> + // then it would not be found by std::find using equality comparison.
> + // We need to check this here, because otherwise something like
> + // memchr("a", 'a'+256, 1) would give a false positive match.
> + if (static_cast<_ValT>(__val) != __val)
> + return __last;
> +
> + const void* __p0 = nullptr;
> + if constexpr (is_pointer_v<decltype(std::__niter_base(__first))>)
> + __p0 = std::__niter_base(__first);
> +#if __cpp_lib_concepts
> + else if constexpr (contiguous_iterator<_InputIterator>)
> + __p0 = std::to_address(__first);
> +#endif
> +
> + if (__p0)
> + {
> + auto __n = std::distance(__first, __last);
> + if (__n > 0)
> + {
> + const int __ival = static_cast<int>(__val);
> + if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
> + return __first + ((const char*)__p1 - (const char*)__p0);
> + }
> + return __last;
> + }
> + }
> +#endif
> +
> return std::__find_if(__first, __last,
> __gnu_cxx::__ops::__iter_equals_val(__val));
> }
> diff --git a/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
> new file mode 100644
> index 00000000000..ac189dac65f
> --- /dev/null
> +++ b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
> @@ -0,0 +1,112 @@
> +// { dg-do run }
> +
> +#include <algorithm>
> +#include <cstddef> // std::byte
> +#include <testsuite_hooks.h>
> +
> +// PR libstdc++/88545 made std::find use memchr as an optimization.
> +// This test verifies that it didn't change any semantics.
> +
> +template<typename C>
> +void
> +test_char()
> +{
> + const C a[] = { (C)'a', (C)'b', (C)'c', (C)'d' };
> + const C* end = a + sizeof(a);
> + const C* res = std::find(a, end, a[0]);
> + VERIFY( res == a );
> + res = std::find(a, end, a[2]);
> + VERIFY( res == a+2 );
> + res = std::find(a, end, a[0] + 256);
> + VERIFY( res == end );
> + res = std::find(a, end, a[0] - 256);
> + VERIFY( res == end );
> + res = std::find(a, end, 256);
> + VERIFY( res == end );
> +
> +#ifdef __cpp_lib_ranges
> + res = std::ranges::find(a, a[0]);
> + VERIFY( res == a );
> + res = std::ranges::find(a, a[2]);
> + VERIFY( res == a+2 );
> + res = std::ranges::find(a, a[0] + 256);
> + VERIFY( res == end );
> + res = std::ranges::find(a, a[0] - 256);
> + VERIFY( res == end );
> + res = std::ranges::find(a, 256);
> + VERIFY( res == end );
> +#endif
> +}
> +
> +// Trivial type of size 1, with custom equality.
> +struct S {
> + bool operator==(const S&) const { return true; };
> + char c;
> +};
> +
> +// Trivial type of size 1, with custom equality.
> +enum E
> +#if __cplusplus >= 201103L
> +: unsigned char
> +#endif
> +{ e1 = 1, e255 = 255 };
> +
> +bool operator==(E l, E r) { return (l % 3) == (r % 3); }
> +
> +struct X { char c; };
> +bool operator==(X, char) { return false; }
> +bool operator==(char, X) { return false; }
> +
> +void
> +test_non_characters()
> +{
> + S s[3] = { {'a'}, {'b'}, {'c'} };
> + S sx = {'x'};
> + S* sres = std::find(s, s+3, sx);
> + VERIFY( sres == s ); // memchr optimization would not find a match
> +
> + E e[3] = { E(1), E(2), E(3) };
> + E* eres = std::find(e, e+3, E(4));
> + VERIFY( eres == e ); // memchr optimization would not find a match
> +
> + char x[1] = { 'x' };
> + X xx = { 'x' };
> + char* xres = std::find(x, x+1, xx);
> + VERIFY( xres == x+1 ); // memchr optimization would find a match
> +
> +#ifdef __cpp_lib_byte
> + std::byte b[] = { std::byte{0}, std::byte{1}, std::byte{2}, std::byte{3} };
> + std::byte* bres = std::find(b, b+4, std::byte{4});
> + VERIFY( bres == b+4 );
> + bres = std::find(b, b+2, std::byte{3});
> + VERIFY( bres == b+2 );
> + bres = std::find(b, b+3, std::byte{3});
> + VERIFY( bres == b+3 );
> +#endif
> +
> +#ifdef __cpp_lib_ranges
> + sres = std::ranges::find(s, sx);
> + VERIFY( sres == s );
> +
> + eres = std::ranges::find(e, e+3, E(4));
> + VERIFY( eres == e );
> +
> + xres = std::ranges::find(x, xx);
> + VERIFY( xres == std::ranges::end(x) );
> +
> + bres = std::ranges::find(b, std::byte{4});
> + VERIFY( bres == b+4 );
> + bres = std::ranges::find(b, b+2, std::byte{3});
> + VERIFY( bres == b+2 );
> + bres = std::ranges::find(b, b+3, std::byte{3});
> + VERIFY( bres == b+3 );
> +#endif
> +}
> +
> +int main()
> +{
> + test_char<char>();
> + test_char<signed char>();
> + test_char<unsigned char>();
> + test_non_characters();
> +}
> --
> 2.45.1
>
@@ -35,6 +35,10 @@
#pragma GCC system_header
#include <bits/c++config.h>
+#include <bits/version.h>
+#if __glibcxx_type_trait_variable_templates
+# include <type_traits> // is_same_v, is_integral_v
+#endif
//
// This file provides some compile-time information about various types.
@@ -589,6 +593,15 @@ __INT_N(__GLIBCXX_TYPE_INT_N_3)
{ static constexpr bool __value = false; };
#endif
+#if __glibcxx_type_trait_variable_templates
+ template<typename _ValT, typename _Tp>
+ constexpr bool __can_use_memchr_for_find
+ // Can only use memchr to search for narrow characters and std::byte.
+ = __is_byte<_ValT>::__value
+ // And only if the value to find is an integer (or is also std::byte).
+ && (is_same_v<_Tp, _ValT> || is_integral_v<_Tp>);
+#endif
+
//
// Move iterator type
//
@@ -34,6 +34,7 @@
# include <bits/ranges_base.h>
# include <bits/utility.h>
# include <bits/invoke.h>
+# include <bits/cpp_type_traits.h> // __can_use_memchr_for_find
#ifdef __glibcxx_ranges
namespace std _GLIBCXX_VISIBILITY(default)
@@ -494,6 +495,22 @@ namespace ranges
operator()(_Iter __first, _Sent __last,
const _Tp& __value, _Proj __proj = {}) const
{
+ if constexpr (is_same_v<_Proj, identity>)
+ if constexpr(__can_use_memchr_for_find<iter_value_t<_Iter>, _Tp>)
+ if constexpr (sized_sentinel_for<_Sent, _Iter>)
+ if constexpr (contiguous_iterator<_Iter>)
+ {
+ auto __n = __last - __first;
+ if (__n > 0)
+ {
+ const int __ival = static_cast<int>(__value);
+ const void* __p0 = std::to_address(__first);
+ if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
+ __n = (const char*)__p1 - (const char*)__p0;
+ }
+ return __first + __n;
+ }
+
while (__first != __last
&& !(std::__invoke(__proj, *__first) == __value))
++__first;
@@ -3836,6 +3836,7 @@ _GLIBCXX_BEGIN_NAMESPACE_ALGO
* such that @c *i == @p __val, or @p __last if no such iterator exists.
*/
template<typename _InputIterator, typename _Tp>
+ _GLIBCXX_NODISCARD
_GLIBCXX20_CONSTEXPR
inline _InputIterator
find(_InputIterator __first, _InputIterator __last,
@@ -3846,6 +3847,40 @@ _GLIBCXX_BEGIN_NAMESPACE_ALGO
__glibcxx_function_requires(_EqualOpConcept<
typename iterator_traits<_InputIterator>::value_type, _Tp>)
__glibcxx_requires_valid_range(__first, __last);
+
+#if __cpp_if_constexpr && __glibcxx_type_trait_variable_templates
+ using _ValT = typename iterator_traits<_InputIterator>::value_type;
+ if constexpr (__can_use_memchr_for_find<_ValT, _Tp>)
+ {
+ // If converting the value to the 1-byte value_type alters its value,
+ // then it would not be found by std::find using equality comparison.
+ // We need to check this here, because otherwise something like
+ // memchr("a", 'a'+256, 1) would give a false positive match.
+ if (static_cast<_ValT>(__val) != __val)
+ return __last;
+
+ const void* __p0 = nullptr;
+ if constexpr (is_pointer_v<decltype(std::__niter_base(__first))>)
+ __p0 = std::__niter_base(__first);
+#if __cpp_lib_concepts
+ else if constexpr (contiguous_iterator<_InputIterator>)
+ __p0 = std::to_address(__first);
+#endif
+
+ if (__p0)
+ {
+ auto __n = std::distance(__first, __last);
+ if (__n > 0)
+ {
+ const int __ival = static_cast<int>(__val);
+ if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
+ return __first + ((const char*)__p1 - (const char*)__p0);
+ }
+ return __last;
+ }
+ }
+#endif
+
return std::__find_if(__first, __last,
__gnu_cxx::__ops::__iter_equals_val(__val));
}
new file mode 100644
@@ -0,0 +1,112 @@
+// { dg-do run }
+
+#include <algorithm>
+#include <cstddef> // std::byte
+#include <testsuite_hooks.h>
+
+// PR libstdc++/88545 made std::find use memchr as an optimization.
+// This test verifies that it didn't change any semantics.
+
+template<typename C>
+void
+test_char()
+{
+ const C a[] = { (C)'a', (C)'b', (C)'c', (C)'d' };
+ const C* end = a + sizeof(a);
+ const C* res = std::find(a, end, a[0]);
+ VERIFY( res == a );
+ res = std::find(a, end, a[2]);
+ VERIFY( res == a+2 );
+ res = std::find(a, end, a[0] + 256);
+ VERIFY( res == end );
+ res = std::find(a, end, a[0] - 256);
+ VERIFY( res == end );
+ res = std::find(a, end, 256);
+ VERIFY( res == end );
+
+#ifdef __cpp_lib_ranges
+ res = std::ranges::find(a, a[0]);
+ VERIFY( res == a );
+ res = std::ranges::find(a, a[2]);
+ VERIFY( res == a+2 );
+ res = std::ranges::find(a, a[0] + 256);
+ VERIFY( res == end );
+ res = std::ranges::find(a, a[0] - 256);
+ VERIFY( res == end );
+ res = std::ranges::find(a, 256);
+ VERIFY( res == end );
+#endif
+}
+
+// Trivial type of size 1, with custom equality.
+struct S {
+ bool operator==(const S&) const { return true; };
+ char c;
+};
+
+// Trivial type of size 1, with custom equality.
+enum E
+#if __cplusplus >= 201103L
+: unsigned char
+#endif
+{ e1 = 1, e255 = 255 };
+
+bool operator==(E l, E r) { return (l % 3) == (r % 3); }
+
+struct X { char c; };
+bool operator==(X, char) { return false; }
+bool operator==(char, X) { return false; }
+
+void
+test_non_characters()
+{
+ S s[3] = { {'a'}, {'b'}, {'c'} };
+ S sx = {'x'};
+ S* sres = std::find(s, s+3, sx);
+ VERIFY( sres == s ); // memchr optimization would not find a match
+
+ E e[3] = { E(1), E(2), E(3) };
+ E* eres = std::find(e, e+3, E(4));
+ VERIFY( eres == e ); // memchr optimization would not find a match
+
+ char x[1] = { 'x' };
+ X xx = { 'x' };
+ char* xres = std::find(x, x+1, xx);
+ VERIFY( xres == x+1 ); // memchr optimization would find a match
+
+#ifdef __cpp_lib_byte
+ std::byte b[] = { std::byte{0}, std::byte{1}, std::byte{2}, std::byte{3} };
+ std::byte* bres = std::find(b, b+4, std::byte{4});
+ VERIFY( bres == b+4 );
+ bres = std::find(b, b+2, std::byte{3});
+ VERIFY( bres == b+2 );
+ bres = std::find(b, b+3, std::byte{3});
+ VERIFY( bres == b+3 );
+#endif
+
+#ifdef __cpp_lib_ranges
+ sres = std::ranges::find(s, sx);
+ VERIFY( sres == s );
+
+ eres = std::ranges::find(e, e+3, E(4));
+ VERIFY( eres == e );
+
+ xres = std::ranges::find(x, xx);
+ VERIFY( xres == std::ranges::end(x) );
+
+ bres = std::ranges::find(b, std::byte{4});
+ VERIFY( bres == b+4 );
+ bres = std::ranges::find(b, b+2, std::byte{3});
+ VERIFY( bres == b+2 );
+ bres = std::ranges::find(b, b+3, std::byte{3});
+ VERIFY( bres == b+3 );
+#endif
+}
+
+int main()
+{
+ test_char<char>();
+ test_char<signed char>();
+ test_char<unsigned char>();
+ test_non_characters();
+}