diff mbox series

libstdc++: Use memchr to optimize std::find [PR88545]

Message ID 20240605153233.119881-1-jwakely@redhat.com
State New
Headers show
Series libstdc++: Use memchr to optimize std::find [PR88545] | expand

Commit Message

Jonathan Wakely June 5, 2024, 3:31 p.m. UTC
I plan to push this after testing finishes.

-- >8 --

This optimizes std::find and std::ranges::find to use memchr when
searching for an integer in a contiguous range of bytes.

libstdc++-v3/ChangeLog:

	PR libstdc++/88545
	PR libstdc++/115040
	* include/bits/cpp_type_traits.h (__can_use_memchr_for_find):
	New variable template.
	* include/bits/ranges_util.h (__find_fn): Use memchr when
	possible.
	* include/bits/stl_algo.h (find): Likewise. Add nodiscard attribute.
	* testsuite/25_algorithms/find/bytes.cc: New test.
---
 libstdc++-v3/include/bits/cpp_type_traits.h   |  13 ++
 libstdc++-v3/include/bits/ranges_util.h       |  17 +++
 libstdc++-v3/include/bits/stl_algo.h          |  35 ++++++
 .../testsuite/25_algorithms/find/bytes.cc     | 112 ++++++++++++++++++
 4 files changed, 177 insertions(+)
 create mode 100644 libstdc++-v3/testsuite/25_algorithms/find/bytes.cc

Comments

Jonathan Wakely June 5, 2024, 5:26 p.m. UTC | #1
This patch needs a tweak to not try to use memchr during constant
evaluation, i.e. check std::is_constant_evaluated().

On Wed, 5 Jun 2024 at 16:34, Jonathan Wakely <jwakely@redhat.com> wrote:
>
> I plan to push this after testing finishes.
>
> -- >8 --
>
> This optimizes std::find to use memchr when searching for an integer in
> a range of bytes.
>
> libstdc++-v3/ChangeLog:
>
>         PR libstdc++/88545
>         PR libstdc++/115040
>         * include/bits/cpp_type_traits.h (__can_use_memchr_for_find):
>         New variable template.
>         * include/bits/ranges_util.h (__find_fn): Use memchr when
>         possible.
>         * include/bits/stl_algo.h (find): Likewise.
>         * testsuite/25_algorithms/find/bytes.cc: New test.
> ---
>  libstdc++-v3/include/bits/cpp_type_traits.h   |  13 ++
>  libstdc++-v3/include/bits/ranges_util.h       |  17 +++
>  libstdc++-v3/include/bits/stl_algo.h          |  35 ++++++
>  .../testsuite/25_algorithms/find/bytes.cc     | 112 ++++++++++++++++++
>  4 files changed, 177 insertions(+)
>  create mode 100644 libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
>
> diff --git a/libstdc++-v3/include/bits/cpp_type_traits.h b/libstdc++-v3/include/bits/cpp_type_traits.h
> index 59f1a1875eb..466e6792a11 100644
> --- a/libstdc++-v3/include/bits/cpp_type_traits.h
> +++ b/libstdc++-v3/include/bits/cpp_type_traits.h
> @@ -35,6 +35,10 @@
>  #pragma GCC system_header
>
>  #include <bits/c++config.h>
> +#include <bits/version.h>
> +#if __glibcxx_type_trait_variable_templates
> +# include <type_traits> // is_same_v, is_integral_v
> +#endif
>
>  //
>  // This file provides some compile-time information about various types.
> @@ -589,6 +593,15 @@ __INT_N(__GLIBCXX_TYPE_INT_N_3)
>      { static constexpr bool __value = false; };
>  #endif
>
> +#if __glibcxx_type_trait_variable_templates
> +  template<typename _ValT, typename _Tp>
> +    constexpr bool __can_use_memchr_for_find
> +    // Can only use memchr to search for narrow characters and std::byte.
> +      = __is_byte<_ValT>::__value
> +       // And only if the value to find is an integer (or is also std::byte).
> +         && (is_same_v<_Tp, _ValT> || is_integral_v<_Tp>);
> +#endif
> +
>    //
>    // Move iterator type
>    //
> diff --git a/libstdc++-v3/include/bits/ranges_util.h b/libstdc++-v3/include/bits/ranges_util.h
> index 9b79c3a229d..7247e89a79d 100644
> --- a/libstdc++-v3/include/bits/ranges_util.h
> +++ b/libstdc++-v3/include/bits/ranges_util.h
> @@ -34,6 +34,7 @@
>  # include <bits/ranges_base.h>
>  # include <bits/utility.h>
>  # include <bits/invoke.h>
> +# include <bits/cpp_type_traits.h> // __can_use_memchr_for_find
>
>  #ifdef __glibcxx_ranges
>  namespace std _GLIBCXX_VISIBILITY(default)
> @@ -494,6 +495,22 @@ namespace ranges
>        operator()(_Iter __first, _Sent __last,
>                  const _Tp& __value, _Proj __proj = {}) const
>        {
> +       if constexpr (is_same_v<_Proj, identity>)
> +         if constexpr(__can_use_memchr_for_find<iter_value_t<_Iter>, _Tp>)
> +           if constexpr (sized_sentinel_for<_Sent, _Iter>)
> +             if constexpr (contiguous_iterator<_Iter>)
> +               {
> +                 auto __n = __last - __first;
> +                 if (__n > 0)
> +                   {
> +                     const int __ival = static_cast<int>(__value);
> +                     const void* __p0 = std::to_address(__first);
> +                     if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
> +                       __n = (const char*)__p1 - (const char*)__p0;
> +                   }
> +                 return __first + __n;
> +               }
> +
>         while (__first != __last
>             && !(std::__invoke(__proj, *__first) == __value))
>           ++__first;
> diff --git a/libstdc++-v3/include/bits/stl_algo.h b/libstdc++-v3/include/bits/stl_algo.h
> index 1a996aa61da..eba3157a480 100644
> --- a/libstdc++-v3/include/bits/stl_algo.h
> +++ b/libstdc++-v3/include/bits/stl_algo.h
> @@ -3836,6 +3836,7 @@ _GLIBCXX_BEGIN_NAMESPACE_ALGO
>     *  such that @c *i == @p __val, or @p __last if no such iterator exists.
>    */
>    template<typename _InputIterator, typename _Tp>
> +    _GLIBCXX_NODISCARD
>      _GLIBCXX20_CONSTEXPR
>      inline _InputIterator
>      find(_InputIterator __first, _InputIterator __last,
> @@ -3846,6 +3847,40 @@ _GLIBCXX_BEGIN_NAMESPACE_ALGO
>        __glibcxx_function_requires(_EqualOpConcept<
>                 typename iterator_traits<_InputIterator>::value_type, _Tp>)
>        __glibcxx_requires_valid_range(__first, __last);
> +
> +#if __cpp_if_constexpr && __glibcxx_type_trait_variable_templates
> +      using _ValT = typename iterator_traits<_InputIterator>::value_type;
> +      if constexpr (__can_use_memchr_for_find<_ValT, _Tp>)
> +       {
> +         // If converting the value to the 1-byte value_type alters its value,
> +         // then it would not be found by std::find using equality comparison.
> +         // We need to check this here, because otherwise something like
> +         // memchr("a", 'a'+256, 1) would give a false positive match.
> +         if (static_cast<_ValT>(__val) != __val)
> +           return __last;
> +
> +         const void* __p0 = nullptr;
> +         if constexpr (is_pointer_v<decltype(std::__niter_base(__first))>)
> +           __p0 = std::__niter_base(__first);
> +#if __cpp_lib_concepts
> +         else if constexpr (contiguous_iterator<_InputIterator>)
> +           __p0 = std::to_address(__first);
> +#endif
> +
> +         if (__p0)
> +           {
> +             auto __n = std::distance(__first, __last);
> +             if (__n > 0)
> +               {
> +                 const int __ival = static_cast<int>(__val);
> +                 if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
> +                   return __first + ((const char*)__p1 - (const char*)__p0);
> +               }
> +             return __last;
> +           }
> +       }
> +#endif
> +
>        return std::__find_if(__first, __last,
>                             __gnu_cxx::__ops::__iter_equals_val(__val));
>      }
> diff --git a/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
> new file mode 100644
> index 00000000000..ac189dac65f
> --- /dev/null
> +++ b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
> @@ -0,0 +1,112 @@
> +// { dg-do run }
> +
> +#include <algorithm>
> +#include <cstddef> // std::byte
> +#include <testsuite_hooks.h>
> +
> +// PR libstdc++/88545 made std::find use memchr as an optimization.
> +// This test verifies that it didn't change any semantics.
> +
> +template<typename C>
> +void
> +test_char()
> +{
> +  const C a[] = { (C)'a', (C)'b', (C)'c', (C)'d' };
> +  const C* end = a + sizeof(a);
> +  const C* res = std::find(a, end, a[0]);
> +  VERIFY( res == a );
> +  res = std::find(a, end, a[2]);
> +  VERIFY( res == a+2 );
> +  res = std::find(a, end, a[0] + 256);
> +  VERIFY( res == end );
> +  res = std::find(a, end, a[0] - 256);
> +  VERIFY( res == end );
> +  res = std::find(a, end, 256);
> +  VERIFY( res == end );
> +
> +#ifdef __cpp_lib_ranges
> +  res = std::ranges::find(a, a[0]);
> +  VERIFY( res == a );
> +  res = std::ranges::find(a, a[2]);
> +  VERIFY( res == a+2 );
> +  res = std::ranges::find(a, a[0] + 256);
> +  VERIFY( res == end );
> +  res = std::ranges::find(a, a[0] - 256);
> +  VERIFY( res == end );
> +  res = std::ranges::find(a, 256);
> +  VERIFY( res == end );
> +#endif
> +}
> +
> +// Trivial type of size 1, with custom equality.
> +struct S {
> +  bool operator==(const S&) const { return true; };
> +  char c;
> +};
> +
> +// Trivial type of size 1, with custom equality.
> +enum E
> +#if __cplusplus >= 201103L
> +: unsigned char
> +#endif
> +{ e1 = 1, e255 = 255 };
> +
> +bool operator==(E l, E r) { return (l % 3) == (r % 3); }
> +
> +struct X { char c; };
> +bool operator==(X, char) { return false; }
> +bool operator==(char, X) { return false; }
> +
> +void
> +test_non_characters()
> +{
> +  S s[3] = { {'a'}, {'b'}, {'c'} };
> +  S sx = {'x'};
> +  S* sres = std::find(s, s+3, sx);
> +  VERIFY( sres == s ); // memchr optimization would not find a match
> +
> +  E e[3] = { E(1), E(2), E(3) };
> +  E* eres = std::find(e, e+3, E(4));
> +  VERIFY( eres == e ); // memchr optimization would not find a match
> +
> +  char x[1] = { 'x' };
> +  X xx = { 'x' };
> +  char* xres = std::find(x, x+1, xx);
> +  VERIFY( xres == x+1 ); // memchr optimization would find a match
> +
> +#ifdef __cpp_lib_byte
> +  std::byte b[] = { std::byte{0}, std::byte{1}, std::byte{2}, std::byte{3} };
> +  std::byte* bres = std::find(b, b+4, std::byte{4});
> +  VERIFY( bres == b+4 );
> +  bres = std::find(b, b+2, std::byte{3});
> +  VERIFY( bres == b+2 );
> +  bres = std::find(b, b+3, std::byte{3});
> +  VERIFY( bres == b+3 );
> +#endif
> +
> +#ifdef __cpp_lib_ranges
> +  sres = std::ranges::find(s, sx);
> +  VERIFY( sres == s );
> +
> +  eres = std::ranges::find(e, e+3, E(4));
> +  VERIFY( eres == e );
> +
> +  xres = std::ranges::find(x, xx);
> +  VERIFY( xres == std::ranges::end(x) );
> +
> +  bres = std::ranges::find(b, std::byte{4});
> +  VERIFY( bres == b+4 );
> +  bres = std::ranges::find(b, b+2, std::byte{3});
> +  VERIFY( bres == b+2 );
> +  bres = std::ranges::find(b, b+3, std::byte{3});
> +  VERIFY( bres == b+3 );
> +#endif
> +}
> +
> +int main()
> +{
> +  test_char<char>();
> +  test_char<signed char>();
> +  test_char<unsigned char>();
> +  test_non_characters();
> +}
> --
> 2.45.1
>
diff mbox series

Patch

diff --git a/libstdc++-v3/include/bits/cpp_type_traits.h b/libstdc++-v3/include/bits/cpp_type_traits.h
index 59f1a1875eb..466e6792a11 100644
--- a/libstdc++-v3/include/bits/cpp_type_traits.h
+++ b/libstdc++-v3/include/bits/cpp_type_traits.h
@@ -35,6 +35,10 @@ 
 #pragma GCC system_header
 
 #include <bits/c++config.h>
+#include <bits/version.h>
+#if __glibcxx_type_trait_variable_templates
+# include <type_traits> // is_same_v, is_integral_v
+#endif
 
 //
 // This file provides some compile-time information about various types.
@@ -589,6 +593,15 @@  __INT_N(__GLIBCXX_TYPE_INT_N_3)
     { static constexpr bool __value = false; };
 #endif
 
+#if __glibcxx_type_trait_variable_templates
+  // True when find can use memchr: _ValT must be a byte type (__is_byte)
+  // and the value's type _Tp must be integral or the same as _ValT.
+  template<typename _ValT, typename _Tp>
+    constexpr bool __can_use_memchr_for_find
+      = __is_byte<_ValT>::__value
+	  && (is_integral_v<_Tp> || is_same_v<_Tp, _ValT>);
+#endif
+
   //
   // Move iterator type
   //
diff --git a/libstdc++-v3/include/bits/ranges_util.h b/libstdc++-v3/include/bits/ranges_util.h
index 9b79c3a229d..7247e89a79d 100644
--- a/libstdc++-v3/include/bits/ranges_util.h
+++ b/libstdc++-v3/include/bits/ranges_util.h
@@ -34,6 +34,7 @@ 
 # include <bits/ranges_base.h>
 # include <bits/utility.h>
 # include <bits/invoke.h>
+# include <bits/cpp_type_traits.h> // __can_use_memchr_for_find
 
 #ifdef __glibcxx_ranges
 namespace std _GLIBCXX_VISIBILITY(default)
@@ -494,6 +495,28 @@  namespace ranges
       operator()(_Iter __first, _Sent __last,
		 const _Tp& __value, _Proj __proj = {}) const
       {
+	if constexpr (is_same_v<_Proj, identity>)
+	  if constexpr(__can_use_memchr_for_find<iter_value_t<_Iter>, _Tp>)
+	    if constexpr (sized_sentinel_for<_Sent, _Iter>)
+	      if constexpr (contiguous_iterator<_Iter>)
+		if (!std::__is_constant_evaluated()) // memchr is not constexpr
+		  {
+		    using _ValT = iter_value_t<_Iter>;
+		    auto __n = __last - __first;
+		    // If converting the value to the 1-byte value type alters
+		    // it, no element can compare equal to it, but memchr would
+		    // truncate it to unsigned char and could report a match.
+		    if (static_cast<_ValT>(__value) == __value && __n > 0)
+		      {
+			const int __ival = static_cast<int>(__value);
+			const void* __p0 = std::to_address(__first);
+			if (auto __p1 = __builtin_memchr(__p0, __ival,
+							 size_t(__n)))
+			  __n = (const char*)__p1 - (const char*)__p0;
+		      }
+		    return __first + __n;
+		  }
+
  	while (__first != __last
	    && !(std::__invoke(__proj, *__first) == __value))
	  ++__first;
diff --git a/libstdc++-v3/include/bits/stl_algo.h b/libstdc++-v3/include/bits/stl_algo.h
index 1a996aa61da..eba3157a480 100644
--- a/libstdc++-v3/include/bits/stl_algo.h
+++ b/libstdc++-v3/include/bits/stl_algo.h
@@ -3836,6 +3836,7 @@  _GLIBCXX_BEGIN_NAMESPACE_ALGO
    *  such that @c *i == @p __val, or @p __last if no such iterator exists.
   */
   template<typename _InputIterator, typename _Tp>
+    _GLIBCXX_NODISCARD
     _GLIBCXX20_CONSTEXPR
     inline _InputIterator
     find(_InputIterator __first, _InputIterator __last,
@@ -3846,6 +3847,46 @@  _GLIBCXX_BEGIN_NAMESPACE_ALGO
       __glibcxx_function_requires(_EqualOpConcept<
		typename iterator_traits<_InputIterator>::value_type, _Tp>)
       __glibcxx_requires_valid_range(__first, __last);
+
+#if __cpp_if_constexpr && __glibcxx_type_trait_variable_templates
+      using _ValT = typename iterator_traits<_InputIterator>::value_type;
+      if constexpr (__can_use_memchr_for_find<_ValT, _Tp>)
+	{
+	  // If converting the value to the 1-byte value_type alters its value,
+	  // then it would not be found by std::find using equality comparison.
+	  // We need to check this here, because otherwise something like
+	  // memchr("a", 'a'+256, 1) would give a false positive match.
+	  if (static_cast<_ValT>(__val) != __val)
+	    return __last;
+
+	  // memchr is not usable in constant expressions, so only use it
+	  // at run time and otherwise fall back to the generic loop below.
+	  if (!std::__is_constant_evaluated())
+	    {
+	      const void* __p0 = nullptr;
+	      if constexpr (is_pointer_v<decltype(std::__niter_base(__first))>)
+		__p0 = std::__niter_base(__first);
+#if __cpp_lib_concepts
+	      else if constexpr (contiguous_iterator<_InputIterator>)
+		__p0 = std::to_address(__first);
+#endif
+
+	      if (__p0)
+		{
+		  auto __n = std::distance(__first, __last);
+		  if (__n > 0)
+		    {
+		      const int __ival = static_cast<int>(__val);
+		      if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
+			return __first
+				 + ((const char*)__p1 - (const char*)__p0);
+		    }
+		  return __last;
+		}
+	    }
+	}
+#endif
+
       return std::__find_if(__first, __last,
			    __gnu_cxx::__ops::__iter_equals_val(__val));
     }
diff --git a/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
new file mode 100644
index 00000000000..ac189dac65f
--- /dev/null
+++ b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
@@ -0,0 +1,133 @@ 
+// { dg-do run }
+
+#include <algorithm>
+#include <cstddef> // std::byte
+#include <testsuite_hooks.h>
+
+// PR libstdc++/88545 made std::find use memchr as an optimization.
+// This test verifies that it didn't change any semantics.
+
+template<typename C>
+void
+test_char()
+{
+  const C a[] = { (C)'a', (C)'b', (C)'c', (C)'d' };
+  const C* end = a + sizeof(a);
+  const C* res = std::find(a, end, a[0]);
+  VERIFY( res == a );
+  res = std::find(a, end, a[2]);
+  VERIFY( res == a+2 );
+  res = std::find(a, end, a[0] + 256);
+  VERIFY( res == end );
+  res = std::find(a, end, a[0] - 256);
+  VERIFY( res == end );
+  res = std::find(a, end, 256);
+  VERIFY( res == end );
+
+#ifdef __cpp_lib_ranges
+  res = std::ranges::find(a, a[0]);
+  VERIFY( res == a );
+  res = std::ranges::find(a, a[2]);
+  VERIFY( res == a+2 );
+  res = std::ranges::find(a, a[0] + 256);
+  VERIFY( res == end );
+  res = std::ranges::find(a, a[0] - 256);
+  VERIFY( res == end );
+  res = std::ranges::find(a, 256);
+  VERIFY( res == end );
+#endif
+}
+
+// Trivial type of size 1, with custom equality.
+struct S {
+  bool operator==(const S&) const { return true; };
+  char c;
+};
+
+// Enumeration type (of size 1 since C++11), with custom equality.
+enum E
+#if __cplusplus >= 201103L
+: unsigned char
+#endif
+{ e1 = 1, e255 = 255 };
+
+bool operator==(E l, E r) { return (l % 3) == (r % 3); }
+
+struct X { char c; };
+bool operator==(X, char) { return false; }
+bool operator==(char, X) { return false; }
+
+void
+test_non_characters()
+{
+  S s[3] = { {'a'}, {'b'}, {'c'} };
+  S sx = {'x'};
+  S* sres = std::find(s, s+3, sx);
+  VERIFY( sres == s ); // memchr optimization would not find a match
+
+  E e[3] = { E(1), E(2), E(3) };
+  E* eres = std::find(e, e+3, E(4));
+  VERIFY( eres == e ); // memchr optimization would not find a match
+
+  char x[1] = { 'x' };
+  X xx = { 'x' };
+  char* xres = std::find(x, x+1, xx);
+  VERIFY( xres == x+1 ); // memchr optimization would find a match
+
+#ifdef __cpp_lib_byte
+  std::byte b[] = { std::byte{0}, std::byte{1}, std::byte{2}, std::byte{3} };
+  std::byte* bres = std::find(b, b+4, std::byte{4});
+  VERIFY( bres == b+4 );
+  bres = std::find(b, b+2, std::byte{3});
+  VERIFY( bres == b+2 );
+  bres = std::find(b, b+3, std::byte{3});
+  VERIFY( bres == b+3 );
+#endif
+
+#ifdef __cpp_lib_ranges
+  sres = std::ranges::find(s, sx);
+  VERIFY( sres == s );
+
+  eres = std::ranges::find(e, e+3, E(4));
+  VERIFY( eres == e );
+
+  // std::equality_comparable_with<char, X> is not satisfied (there is no
+  // common reference type), so std::ranges::find(x, xx) is ill-formed.
+
+  bres = std::ranges::find(b, std::byte{4});
+  VERIFY( bres == b+4 );
+  bres = std::ranges::find(b, b+2, std::byte{3});
+  VERIFY( bres == b+2 );
+  bres = std::ranges::find(b, b+3, std::byte{3});
+  VERIFY( bres == b+3 );
+#endif
+}
+
+#if __cpp_lib_constexpr_algorithms
+// The memchr optimization must not be used during constant evaluation.
+constexpr bool
+test_constexpr()
+{
+  const char a[] = { 'a', 'b', 'c', 'd' };
+  if (std::find(a, a+4, 'c') != a+2)
+    return false;
+  if (std::find(a, a+4, 'a' + 256) != a+4)
+    return false;
+#ifdef __cpp_lib_ranges
+  if (std::ranges::find(a, 'c') != a+2)
+    return false;
+  if (std::ranges::find(a, 'a' + 256) != a+4)
+    return false;
+#endif
+  return true;
+}
+static_assert( test_constexpr() );
+#endif
+
+int main()
+{
+  test_char<char>();
+  test_char<signed char>();
+  test_char<unsigned char>();
+  test_non_characters();
+}