diff mbox series

[v2] libstdc++: Use memchr to optimize std::find [PR88545]

Message ID 20240605174953.143341-1-jwakely@redhat.com
State New
Headers show
Series [v2] libstdc++: Use memchr to optimize std::find [PR88545] | expand

Commit Message

Jonathan Wakely June 5, 2024, 5:47 p.m. UTC
This v2 patch uses std::is_constant_evaluated() so that the algorithms don't
try to use memchr in constant expressions, and removes the
_GLIBCXX_NODISCARD I added to std::find (it should still be added, but
as a separate patch).

Tamar has asked that I don't push this until he compares the performance
to a vectorized std::find_if, so I'll wait.

-- >8 --

This optimizes std::find to use memchr when searching for an integer in
a range of bytes.

libstdc++-v3/ChangeLog:

	PR libstdc++/88545
	PR libstdc++/115040
	* include/bits/cpp_type_traits.h (__can_use_memchr_for_find):
	New variable template.
	* include/bits/ranges_util.h (__find_fn): Use memchr when
	possible.
	* include/bits/stl_algo.h (find): Likewise.
	* testsuite/25_algorithms/find/bytes.cc: New test.
---
 libstdc++-v3/include/bits/cpp_type_traits.h   |  13 ++
 libstdc++-v3/include/bits/ranges_util.h       |  18 +++
 libstdc++-v3/include/bits/stl_algo.h          |  33 +++++
 .../testsuite/25_algorithms/find/bytes.cc     | 135 ++++++++++++++++++
 4 files changed, 199 insertions(+)
 create mode 100644 libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
diff mbox series

Patch

diff --git a/libstdc++-v3/include/bits/cpp_type_traits.h b/libstdc++-v3/include/bits/cpp_type_traits.h
index 59f1a1875eb..466e6792a11 100644
--- a/libstdc++-v3/include/bits/cpp_type_traits.h
+++ b/libstdc++-v3/include/bits/cpp_type_traits.h
@@ -35,6 +35,10 @@ 
 #pragma GCC system_header
 
 #include <bits/c++config.h>
+#include <bits/version.h>
+#if __glibcxx_type_trait_variable_templates
+# include <type_traits> // is_same_v, is_integral_v
+#endif
 
 //
 // This file provides some compile-time information about various types.
@@ -589,6 +593,15 @@  __INT_N(__GLIBCXX_TYPE_INT_N_3)
     { static constexpr bool __value = false; };
 #endif
 
+#if __glibcxx_type_trait_variable_templates
+  // True if std::find can consider memchr: _ValT is a narrow character or
+  // std::byte, and _Tp is an integer type or the same type as _ValT.
+  template<typename _ValT, typename _Tp>
+    constexpr bool __can_use_memchr_for_find
+      = __is_byte<_ValT>::__value
+	  && (is_same_v<_Tp, _ValT> || is_integral_v<_Tp>);
+#endif
+
   //
   // Move iterator type
   //
diff --git a/libstdc++-v3/include/bits/ranges_util.h b/libstdc++-v3/include/bits/ranges_util.h
index 9b79c3a229d..03239fd8af6 100644
--- a/libstdc++-v3/include/bits/ranges_util.h
+++ b/libstdc++-v3/include/bits/ranges_util.h
@@ -34,6 +34,7 @@ 
 # include <bits/ranges_base.h>
 # include <bits/utility.h>
 # include <bits/invoke.h>
+# include <bits/cpp_type_traits.h> // __can_use_memchr_for_find
 
 #ifdef __glibcxx_ranges
 namespace std _GLIBCXX_VISIBILITY(default)
@@ -494,6 +495,32 @@  namespace ranges
       operator()(_Iter __first, _Sent __last,
 		 const _Tp& __value, _Proj __proj = {}) const
       {
+	// Fast path: with no projection, 1-byte elements, an integer (or
+	// std::byte) value and a sized contiguous range, memchr can do the
+	// search.  It is not usable during constant evaluation.
+	if constexpr (is_same_v<_Proj, identity>)
+	  if constexpr(__can_use_memchr_for_find<iter_value_t<_Iter>, _Tp>)
+	    if constexpr (sized_sentinel_for<_Sent, _Iter>)
+	      if constexpr (contiguous_iterator<_Iter>)
+		if (!is_constant_evaluated())
+		  {
+		    using _Vt = iter_value_t<_Iter>;
+		    auto __n = __last - __first;
+		    // If converting the value to the 1-byte value type alters
+		    // its value then no element can compare equal to it, but
+		    // memchr would truncate it and could report a false match,
+		    // e.g. searching chars for 'a'+256 would find 'a'.
+		    if (__n > 0 && static_cast<_Vt>(__value) == __value)
+		      {
+			const int __ival = static_cast<int>(__value);
+			const void* __p0 = std::to_address(__first);
+			if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
+			  __n = (const char*)__p1 - (const char*)__p0;
+		      }
+		    // __n is now the offset of the match, or the full length.
+		    return __first + __n;
+		  }
+
 	while (__first != __last
 	    && !(std::__invoke(__proj, *__first) == __value))
 	  ++__first;
diff --git a/libstdc++-v3/include/bits/stl_algo.h b/libstdc++-v3/include/bits/stl_algo.h
index 1a996aa61da..4daaf60f289 100644
--- a/libstdc++-v3/include/bits/stl_algo.h
+++ b/libstdc++-v3/include/bits/stl_algo.h
@@ -3846,6 +3846,39 @@  _GLIBCXX_BEGIN_NAMESPACE_ALGO
       __glibcxx_function_requires(_EqualOpConcept<
 		typename iterator_traits<_InputIterator>::value_type, _Tp>)
       __glibcxx_requires_valid_range(__first, __last);
+
+#if __cpp_if_constexpr && __glibcxx_type_trait_variable_templates
+      using _ValT = typename iterator_traits<_InputIterator>::value_type;
+      if constexpr (__can_use_memchr_for_find<_ValT, _Tp>)
+	{
+	  // If converting the value to the 1-byte value_type alters its value,
+	  // then it would not be found by std::find using equality comparison.
+	  // We need to check this here, because otherwise something like
+	  // memchr("a", 'a'+256, 1) would give a false positive match.
+	  if (static_cast<_ValT>(__val) != __val)
+	    return __last;
+	  else if (!__is_constant_evaluated()) // No memchr in constant exprs.
+	    {
+	      const void* __p0 = nullptr; // Set only for pointers/contiguous.
+	      if constexpr (is_pointer_v<decltype(std::__niter_base(__first))>)
+		__p0 = std::__niter_base(__first);
+#if __cpp_lib_concepts
+	      else if constexpr (contiguous_iterator<_InputIterator>)
+		__p0 = std::to_address(__first);
+#endif
+	      // If __p0 is null, fall through to the generic search below.
+	      if (__p0)
+		{
+		  const int __ival = static_cast<int>(__val);
+		  if (auto __n = std::distance(__first, __last); __n > 0)
+		    if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
+		      return __first + ((const char*)__p1 - (const char*)__p0);
+		  return __last; // Empty range, or memchr found no match.
+		}
+	    }
+	}
+#endif
+
       return std::__find_if(__first, __last,
 			    __gnu_cxx::__ops::__iter_equals_val(__val));
     }
diff --git a/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
new file mode 100644
index 00000000000..2f643c4c4b4
--- /dev/null
+++ b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
@@ -0,0 +1,135 @@ 
+// { dg-do run }
+
+#include <algorithm>
+#include <cstddef> // std::byte
+#include <testsuite_hooks.h>
+
+// PR libstdc++/88545 made std::find use memchr as an optimization.
+// This test verifies that it didn't change any semantics.
+
+template<typename C>
+void
+test_char() // Checks find on an array of 1-byte characters of type C.
+{
+  const C a[] = { (C)'a', (C)'b', (C)'c', (C)'d' };
+  const C* end = a + sizeof(a); // Element count, as C is a 1-byte type.
+  const C* res = std::find(a, end, a[0]);
+  VERIFY( res == a );
+  res = std::find(a, end, a[2]);
+  VERIFY( res == a+2 );
+  res = std::find(a, end, a[0] + 256); // Same low byte as a[0].
+  VERIFY( res == end );
+  res = std::find(a, end, a[0] - 256); // Same low byte as a[0].
+  VERIFY( res == end );
+  res = std::find(a, end, 256); // Low byte is zero, not in the array.
+  VERIFY( res == end );
+  // Repeat the same checks with std::ranges::find.
+#ifdef __cpp_lib_ranges
+  res = std::ranges::find(a, a[0]);
+  VERIFY( res == a );
+  res = std::ranges::find(a, a[2]);
+  VERIFY( res == a+2 );
+  res = std::ranges::find(a, a[0] + 256); // Same low byte as a[0].
+  VERIFY( res == end );
+  res = std::ranges::find(a, a[0] - 256); // Same low byte as a[0].
+  VERIFY( res == end );
+  res = std::ranges::find(a, 256); // Low byte is zero, not in the array.
+  VERIFY( res == end );
+#endif
+}
+
+// Trivial type of size 1 whose operator== reports every pair as equal.
+struct S {
+  bool operator==(const S&) const { return true; };
+  char c;
+};
+
+// Enumeration with custom equality: values are equal if congruent modulo 3.
+enum E
+#if __cplusplus >= 201103L
+: unsigned char // Fixed 1-byte underlying type, only in C++11 and later.
+#endif
+{ e1 = 1, e255 = 255 };
+
+bool operator==(E l, E r) { return (l % 3) == (r % 3); }
+
+struct X { char c; }; // Compares unequal to every char.
+bool operator==(X, char) { return false; }
+bool operator==(char, X) { return false; }
+
+bool operator==(E, char) { return false; } // E never equals a char either.
+bool operator==(char, E) { return false; }
+
+void
+test_non_characters() // Types where memchr and operator== would disagree.
+{
+  S s[3] = { {'a'}, {'b'}, {'c'} };
+  S sx = {'x'};
+  S* sres = std::find(s, s+3, sx);
+  VERIFY( sres == s ); // S always compares equal; memchr would not match
+
+  E e[3] = { E(1), E(2), E(3) };
+  E* eres = std::find(e, e+3, E(4));
+  VERIFY( eres == e ); // 4 % 3 == 1 % 3; memchr would not find a match
+
+  char x[1] = { 'x' };
+  X xx = { 'x' };
+  char* xres = std::find(x, x+1, xx);
+  VERIFY( xres == x+1 ); // char == X is false; memchr would find a match
+  xres = std::find(x, x+1, E('x'));
+  VERIFY( xres == x+1 ); // char == E is false; memchr would find a match
+
+#ifdef __cpp_lib_byte
+  std::byte b[] = { std::byte{0}, std::byte{1}, std::byte{2}, std::byte{3} };
+  std::byte* bres = std::find(b, b+4, std::byte{4});
+  VERIFY( bres == b+4 );
+  bres = std::find(b, b+2, std::byte{3}); // Match lies past the sub-range.
+  VERIFY( bres == b+2 );
+  bres = std::find(b, b+3, std::byte{3}); // Match lies past the sub-range.
+  VERIFY( bres == b+3 );
+#endif
+
+#ifdef __cpp_lib_ranges
+  sres = std::ranges::find(s, sx);
+  VERIFY( sres == s );
+
+  eres = std::ranges::find(e, e+3, E(4));
+  VERIFY( eres == e );
+
+  xres = std::ranges::find(x, xx);
+  VERIFY( xres == std::ranges::end(x) );
+
+  bres = std::ranges::find(b, std::byte{4});
+  VERIFY( bres == b+4 );
+  bres = std::ranges::find(b, b+2, std::byte{3});
+  VERIFY( bres == b+2 );
+  bres = std::ranges::find(b, std::byte{3});
+  VERIFY( bres == b+3 );
+
+  xres = std::find(x, x+1, xx); // Plain std::find, checked via ranges::end.
+  VERIFY( xres == std::ranges::end(x) );
+  xres = std::find(x, x+1, E('x'));
+  VERIFY( xres == std::ranges::end(x) );
+#endif
+}
+
+int main()
+{
+  test_char<char>();
+  test_char<signed char>();
+  test_char<unsigned char>();
+  test_non_characters();
+  // The algorithms must still be usable in constant expressions.
+#if __cpp_lib_constexpr_algorithms
+  static_assert( [] {
+    char c[] = "abcd";
+    return std::find(c, c+4, 'b') == c+1;
+  }() );
+#ifdef __cpp_lib_ranges
+  static_assert( [] {
+    char c[] = "abcd";
+    return std::ranges::find(c, 'b') == c+1;
+  }() );
+#endif
+#endif
+}