@@ -35,6 +35,10 @@
#pragma GCC system_header
#include <bits/c++config.h>
+#include <bits/version.h>
+#if __glibcxx_type_trait_variable_templates
+# include <type_traits> // is_same_v, is_integral_v
+#endif
//
// This file provides some compile-time information about various types.
@@ -589,6 +593,15 @@ __INT_N(__GLIBCXX_TYPE_INT_N_3)
{ static constexpr bool __value = false; };
#endif
+#if __glibcxx_type_trait_variable_templates
+ template<typename _ValT, typename _Tp>
+ constexpr bool __can_use_memchr_for_find
+ // Can only use memchr to search for narrow characters and std::byte.
+ = __is_byte<_ValT>::__value
+ // And only if the value to find is an integer (or is also std::byte).
+ && (is_same_v<_Tp, _ValT> || is_integral_v<_Tp>);
+#endif
+
//
// Move iterator type
//
@@ -34,6 +34,7 @@
# include <bits/ranges_base.h>
# include <bits/utility.h>
# include <bits/invoke.h>
+# include <bits/cpp_type_traits.h> // __can_use_memchr_for_find
#ifdef __glibcxx_ranges
namespace std _GLIBCXX_VISIBILITY(default)
@@ -494,6 +495,23 @@ namespace ranges
operator()(_Iter __first, _Sent __last,
const _Tp& __value, _Proj __proj = {}) const
{
+ if constexpr (is_same_v<_Proj, identity>)
+ if constexpr(__can_use_memchr_for_find<iter_value_t<_Iter>, _Tp>)
+ if constexpr (sized_sentinel_for<_Sent, _Iter>)
+ if constexpr (contiguous_iterator<_Iter>)
+ if (!is_constant_evaluated())
+ {
+ auto __n = __last - __first;
+ if (__n > 0)
+ {
+ const int __ival = static_cast<int>(__value);
+ const void* __p0 = std::to_address(__first);
+ if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
+ __n = (const char*)__p1 - (const char*)__p0;
+ }
+ return __first + __n;
+ }
+
while (__first != __last
&& !(std::__invoke(__proj, *__first) == __value))
++__first;
@@ -3846,6 +3846,39 @@ _GLIBCXX_BEGIN_NAMESPACE_ALGO
__glibcxx_function_requires(_EqualOpConcept<
typename iterator_traits<_InputIterator>::value_type, _Tp>)
__glibcxx_requires_valid_range(__first, __last);
+
+#if __cpp_if_constexpr && __glibcxx_type_trait_variable_templates
+ using _ValT = typename iterator_traits<_InputIterator>::value_type;
+ if constexpr (__can_use_memchr_for_find<_ValT, _Tp>)
+ {
+ // If converting the value to the 1-byte value_type alters its value,
+ // then it would not be found by std::find using equality comparison.
+ // We need to check this here, because otherwise something like
+ // memchr("a", 'a'+256, 1) would give a false positive match.
+ if (static_cast<_ValT>(__val) != __val)
+ return __last;
+ else if (!__is_constant_evaluated())
+ {
+ const void* __p0 = nullptr;
+ if constexpr (is_pointer_v<decltype(std::__niter_base(__first))>)
+ __p0 = std::__niter_base(__first);
+#if __cpp_lib_concepts
+ else if constexpr (contiguous_iterator<_InputIterator>)
+ __p0 = std::to_address(__first);
+#endif
+
+ if (__p0)
+ {
+ const int __ival = static_cast<int>(__val);
+ if (auto __n = std::distance(__first, __last); __n > 0)
+ if (auto __p1 = __builtin_memchr(__p0, __ival, __n))
+ return __first + ((const char*)__p1 - (const char*)__p0);
+ return __last;
+ }
+ }
+ }
+#endif
+
return std::__find_if(__first, __last,
__gnu_cxx::__ops::__iter_equals_val(__val));
}
new file mode 100644
@@ -0,0 +1,135 @@
+// { dg-do run }
+
+#include <algorithm>
+#include <cstddef> // std::byte
+#include <testsuite_hooks.h>
+
+// PR libstdc++/88545 made std::find use memchr as an optimization.
+// This test verifies that it didn't change any semantics.
+
+template<typename C>
+void
+test_char()
+{
+ const C a[] = { (C)'a', (C)'b', (C)'c', (C)'d' };
+ const C* end = a + sizeof(a);
+ const C* res = std::find(a, end, a[0]);
+ VERIFY( res == a );
+ res = std::find(a, end, a[2]);
+ VERIFY( res == a+2 );
+ res = std::find(a, end, a[0] + 256);
+ VERIFY( res == end );
+ res = std::find(a, end, a[0] - 256);
+ VERIFY( res == end );
+ res = std::find(a, end, 256);
+ VERIFY( res == end );
+
+#ifdef __cpp_lib_ranges
+ res = std::ranges::find(a, a[0]);
+ VERIFY( res == a );
+ res = std::ranges::find(a, a[2]);
+ VERIFY( res == a+2 );
+ res = std::ranges::find(a, a[0] + 256);
+ VERIFY( res == end );
+ res = std::ranges::find(a, a[0] - 256);
+ VERIFY( res == end );
+ res = std::ranges::find(a, 256);
+ VERIFY( res == end );
+#endif
+}
+
+// Trivial type of size 1, with custom equality.
+struct S {
+ bool operator==(const S&) const { return true; };
+ char c;
+};
+
+// Trivial type of size 1, with custom equality.
+enum E
+#if __cplusplus >= 201103L
+: unsigned char
+#endif
+{ e1 = 1, e255 = 255 };
+
+bool operator==(E l, E r) { return (l % 3) == (r % 3); }
+
+struct X { char c; };
+bool operator==(X, char) { return false; }
+bool operator==(char, X) { return false; }
+
+bool operator==(E, char) { return false; }
+bool operator==(char, E) { return false; }
+
+void
+test_non_characters()
+{
+ S s[3] = { {'a'}, {'b'}, {'c'} };
+ S sx = {'x'};
+ S* sres = std::find(s, s+3, sx);
+ VERIFY( sres == s ); // memchr optimization would not find a match
+
+ E e[3] = { E(1), E(2), E(3) };
+ E* eres = std::find(e, e+3, E(4));
+ VERIFY( eres == e ); // memchr optimization would not find a match
+
+ char x[1] = { 'x' };
+ X xx = { 'x' };
+ char* xres = std::find(x, x+1, xx);
+ VERIFY( xres == x+1 ); // memchr optimization would find a match
+ xres = std::find(x, x+1, E('x'));
+ VERIFY( xres == x+1 ); // memchr optimization would find a match
+
+#ifdef __cpp_lib_byte
+ std::byte b[] = { std::byte{0}, std::byte{1}, std::byte{2}, std::byte{3} };
+ std::byte* bres = std::find(b, b+4, std::byte{4});
+ VERIFY( bres == b+4 );
+ bres = std::find(b, b+2, std::byte{3});
+ VERIFY( bres == b+2 );
+ bres = std::find(b, b+3, std::byte{3});
+ VERIFY( bres == b+3 );
+#endif
+
+#ifdef __cpp_lib_ranges
+ sres = std::ranges::find(s, sx);
+ VERIFY( sres == s );
+
+ eres = std::ranges::find(e, e+3, E(4));
+ VERIFY( eres == e );
+
+ xres = std::ranges::find(x, xx);
+ VERIFY( xres == std::ranges::end(x) );
+
+ bres = std::ranges::find(b, std::byte{4});
+ VERIFY( bres == b+4 );
+ bres = std::ranges::find(b, b+2, std::byte{3});
+ VERIFY( bres == b+2 );
+ bres = std::ranges::find(b, std::byte{3});
+ VERIFY( bres == b+3 );
+
+ xres = std::find(x, x+1, xx);
+ VERIFY( xres == std::ranges::end(x) );
+ xres = std::find(x, x+1, E('x'));
+ VERIFY( xres == std::ranges::end(x) );
+#endif
+}
+
+int main()
+{
+ test_char<char>();
+ test_char<signed char>();
+ test_char<unsigned char>();
+ test_non_characters();
+
+#if __cpp_lib_constexpr_algorithms
+ static_assert( [] {
+ char c[] = "abcd";
+ return std::find(c, c+4, 'b') == c+1;
+ }() );
+#ifdef __cpp_lib_ranges
+ static_assert( [] {
+ char c[] = "abcd";
+ return std::ranges::find(c, 'b') == c+1;
+ }() );
+#endif
+#endif
+}