Message ID | 20240605153233.119881-1-jwakely@redhat.com |
---|---|
State | New |
Headers | show |
Series | libstdc++: Use memchr to optimize std::find [PR88545] | expand |
This patch needs a tweak to not try to use memchr during constant evaluation, i.e. check std::is_constant_evaluated(). On Wed, 5 Jun 2024 at 16:34, Jonathan Wakely <jwakely@redhat.com> wrote: > > I plan to push this after testing finishes. > > -- >8 -- > > This optimizes std::find to use memchr when searching for an integer in > a range of bytes. > > libstdc++-v3/ChangeLog: > > PR libstdc++/88545 > PR libstdc++/115040 > * include/bits/cpp_type_traits.h (__can_use_memchr_for_find): > New variable template. > * include/bits/ranges_util.h (__find_fn): Use memchr when > possible. > * include/bits/stl_algo.h (find): Likewise. > * testsuite/25_algorithms/find/bytes.cc: New test. > --- > libstdc++-v3/include/bits/cpp_type_traits.h | 13 ++ > libstdc++-v3/include/bits/ranges_util.h | 17 +++ > libstdc++-v3/include/bits/stl_algo.h | 35 ++++++ > .../testsuite/25_algorithms/find/bytes.cc | 112 ++++++++++++++++++ > 4 files changed, 177 insertions(+) > create mode 100644 libstdc++-v3/testsuite/25_algorithms/find/bytes.cc > > diff --git a/libstdc++-v3/include/bits/cpp_type_traits.h b/libstdc++-v3/include/bits/cpp_type_traits.h > index 59f1a1875eb..466e6792a11 100644 > --- a/libstdc++-v3/include/bits/cpp_type_traits.h > +++ b/libstdc++-v3/include/bits/cpp_type_traits.h > @@ -35,6 +35,10 @@ > #pragma GCC system_header > > #include <bits/c++config.h> > +#include <bits/version.h> > +#if __glibcxx_type_trait_variable_templates > +# include <type_traits> // is_same_v, is_integral_v > +#endif > > // > // This file provides some compile-time information about various types. > @@ -589,6 +593,15 @@ __INT_N(__GLIBCXX_TYPE_INT_N_3) > { static constexpr bool __value = false; }; > #endif > > +#if __glibcxx_type_trait_variable_templates > + template<typename _ValT, typename _Tp> > + constexpr bool __can_use_memchr_for_find > + // Can only use memchr to search for narrow characters and std::byte. 
> + = __is_byte<_ValT>::__value > + // And only if the value to find is an integer (or is also std::byte). > + && (is_same_v<_Tp, _ValT> || is_integral_v<_Tp>); > +#endif > + > // > // Move iterator type > // > diff --git a/libstdc++-v3/include/bits/ranges_util.h b/libstdc++-v3/include/bits/ranges_util.h > index 9b79c3a229d..7247e89a79d 100644 > --- a/libstdc++-v3/include/bits/ranges_util.h > +++ b/libstdc++-v3/include/bits/ranges_util.h > @@ -34,6 +34,7 @@ > # include <bits/ranges_base.h> > # include <bits/utility.h> > # include <bits/invoke.h> > +# include <bits/cpp_type_traits.h> // __can_use_memchr_for_find > > #ifdef __glibcxx_ranges > namespace std _GLIBCXX_VISIBILITY(default) > @@ -494,6 +495,22 @@ namespace ranges > operator()(_Iter __first, _Sent __last, > const _Tp& __value, _Proj __proj = {}) const > { > + if constexpr (is_same_v<_Proj, identity>) > + if constexpr(__can_use_memchr_for_find<iter_value_t<_Iter>, _Tp>) > + if constexpr (sized_sentinel_for<_Sent, _Iter>) > + if constexpr (contiguous_iterator<_Iter>) > + { > + auto __n = __last - __first; > + if (__n > 0) > + { > + const int __ival = static_cast<int>(__value); > + const void* __p0 = std::to_address(__first); > + if (auto __p1 = __builtin_memchr(__p0, __ival, __n)) > + __n = (const char*)__p1 - (const char*)__p0; > + } > + return __first + __n; > + } > + > while (__first != __last > && !(std::__invoke(__proj, *__first) == __value)) > ++__first; > diff --git a/libstdc++-v3/include/bits/stl_algo.h b/libstdc++-v3/include/bits/stl_algo.h > index 1a996aa61da..eba3157a480 100644 > --- a/libstdc++-v3/include/bits/stl_algo.h > +++ b/libstdc++-v3/include/bits/stl_algo.h > @@ -3836,6 +3836,7 @@ _GLIBCXX_BEGIN_NAMESPACE_ALGO > * such that @c *i == @p __val, or @p __last if no such iterator exists. 
> */ > template<typename _InputIterator, typename _Tp> > + _GLIBCXX_NODISCARD > _GLIBCXX20_CONSTEXPR > inline _InputIterator > find(_InputIterator __first, _InputIterator __last, > @@ -3846,6 +3847,40 @@ _GLIBCXX_BEGIN_NAMESPACE_ALGO > __glibcxx_function_requires(_EqualOpConcept< > typename iterator_traits<_InputIterator>::value_type, _Tp>) > __glibcxx_requires_valid_range(__first, __last); > + > +#if __cpp_if_constexpr && __glibcxx_type_trait_variable_templates > + using _ValT = typename iterator_traits<_InputIterator>::value_type; > + if constexpr (__can_use_memchr_for_find<_ValT, _Tp>) > + { > + // If converting the value to the 1-byte value_type alters its value, > + // then it would not be found by std::find using equality comparison. > + // We need to check this here, because otherwise something like > + // memchr("a", 'a'+256, 1) would give a false positive match. > + if (static_cast<_ValT>(__val) != __val) > + return __last; > + > + const void* __p0 = nullptr; > + if constexpr (is_pointer_v<decltype(std::__niter_base(__first))>) > + __p0 = std::__niter_base(__first); > +#if __cpp_lib_concepts > + else if constexpr (contiguous_iterator<_InputIterator>) > + __p0 = std::to_address(__first); > +#endif > + > + if (__p0) > + { > + auto __n = std::distance(__first, __last); > + if (__n > 0) > + { > + const int __ival = static_cast<int>(__val); > + if (auto __p1 = __builtin_memchr(__p0, __ival, __n)) > + return __first + ((const char*)__p1 - (const char*)__p0); > + } > + return __last; > + } > + } > +#endif > + > return std::__find_if(__first, __last, > __gnu_cxx::__ops::__iter_equals_val(__val)); > } > diff --git a/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc > new file mode 100644 > index 00000000000..ac189dac65f > --- /dev/null > +++ b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc > @@ -0,0 +1,112 @@ > +// { dg-do run } > + > +#include <algorithm> > +#include <cstddef> // std::byte > +#include 
<testsuite_hooks.h> > + > +// PR libstdc++/88545 made std::find use memchr as an optimization. > +// This test verifies that it didn't change any semantics. > + > +template<typename C> > +void > +test_char() > +{ > + const C a[] = { (C)'a', (C)'b', (C)'c', (C)'d' }; > + const C* end = a + sizeof(a); > + const C* res = std::find(a, end, a[0]); > + VERIFY( res == a ); > + res = std::find(a, end, a[2]); > + VERIFY( res == a+2 ); > + res = std::find(a, end, a[0] + 256); > + VERIFY( res == end ); > + res = std::find(a, end, a[0] - 256); > + VERIFY( res == end ); > + res = std::find(a, end, 256); > + VERIFY( res == end ); > + > +#ifdef __cpp_lib_ranges > + res = std::ranges::find(a, a[0]); > + VERIFY( res == a ); > + res = std::ranges::find(a, a[2]); > + VERIFY( res == a+2 ); > + res = std::ranges::find(a, a[0] + 256); > + VERIFY( res == end ); > + res = std::ranges::find(a, a[0] - 256); > + VERIFY( res == end ); > + res = std::ranges::find(a, 256); > + VERIFY( res == end ); > +#endif > +} > + > +// Trivial type of size 1, with custom equality. > +struct S { > + bool operator==(const S&) const { return true; }; > + char c; > +}; > + > +// Trivial type of size 1, with custom equality. 
> +enum E > +#if __cplusplus >= 201103L > +: unsigned char > +#endif > +{ e1 = 1, e255 = 255 }; > + > +bool operator==(E l, E r) { return (l % 3) == (r % 3); } > + > +struct X { char c; }; > +bool operator==(X, char) { return false; } > +bool operator==(char, X) { return false; } > + > +void > +test_non_characters() > +{ > + S s[3] = { {'a'}, {'b'}, {'c'} }; > + S sx = {'x'}; > + S* sres = std::find(s, s+3, sx); > + VERIFY( sres == s ); // memchr optimization would not find a match > + > + E e[3] = { E(1), E(2), E(3) }; > + E* eres = std::find(e, e+3, E(4)); > + VERIFY( eres == e ); // memchr optimization would not find a match > + > + char x[1] = { 'x' }; > + X xx = { 'x' }; > + char* xres = std::find(x, x+1, xx); > + VERIFY( xres == x+1 ); // memchr optimization would find a match > + > +#ifdef __cpp_lib_byte > + std::byte b[] = { std::byte{0}, std::byte{1}, std::byte{2}, std::byte{3} }; > + std::byte* bres = std::find(b, b+4, std::byte{4}); > + VERIFY( bres == b+4 ); > + bres = std::find(b, b+2, std::byte{3}); > + VERIFY( bres == b+2 ); > + bres = std::find(b, b+3, std::byte{3}); > + VERIFY( bres == b+3 ); > +#endif > + > +#ifdef __cpp_lib_ranges > + sres = std::ranges::find(s, sx); > + VERIFY( sres == s ); > + > + eres = std::ranges::find(e, e+3, E(4)); > + VERIFY( eres == e ); > + > + xres = std::ranges::find(x, xx); > + VERIFY( xres == std::ranges::end(x) ); > + > + bres = std::ranges::find(b, std::byte{4}); > + VERIFY( bres == b+4 ); > + bres = std::ranges::find(b, b+2, std::byte{3}); > + VERIFY( bres == b+2 ); > + bres = std::ranges::find(b, b+3, std::byte{3}); > + VERIFY( bres == b+3 ); > +#endif > +} > + > +int main() > +{ > + test_char<char>(); > + test_char<signed char>(); > + test_char<unsigned char>(); > + test_non_characters(); > +} > -- > 2.45.1 >
diff --git a/libstdc++-v3/include/bits/cpp_type_traits.h b/libstdc++-v3/include/bits/cpp_type_traits.h
index 59f1a1875eb..466e6792a11 100644
--- a/libstdc++-v3/include/bits/cpp_type_traits.h
+++ b/libstdc++-v3/include/bits/cpp_type_traits.h
@@ -35,6 +35,10 @@
 #pragma GCC system_header
 
 #include <bits/c++config.h>
+#include <bits/version.h>
+#if __glibcxx_type_trait_variable_templates
+# include <type_traits> // is_same_v, is_integral_v
+#endif
 
 //
 // This file provides some compile-time information about various types.
@@ -589,6 +593,15 @@ __INT_N(__GLIBCXX_TYPE_INT_N_3)
     { static constexpr bool __value = false; };
 #endif
 
+#if __glibcxx_type_trait_variable_templates
+  template<typename _ValT, typename _Tp>
+    constexpr bool __can_use_memchr_for_find
+    // Can only use memchr to search for narrow characters and std::byte.
+      = __is_byte<_ValT>::__value
+        // And only if the value to find is an integer (or is also std::byte).
+        && (is_same_v<_Tp, _ValT> || is_integral_v<_Tp>);
+#endif
+
 //
 // Move iterator type
 //
diff --git a/libstdc++-v3/include/bits/ranges_util.h b/libstdc++-v3/include/bits/ranges_util.h
index 9b79c3a229d..7247e89a79d 100644
--- a/libstdc++-v3/include/bits/ranges_util.h
+++ b/libstdc++-v3/include/bits/ranges_util.h
@@ -34,6 +34,7 @@
 # include <bits/ranges_base.h>
 # include <bits/utility.h>
 # include <bits/invoke.h>
+# include <bits/cpp_type_traits.h> // __can_use_memchr_for_find
 
 #ifdef __glibcxx_ranges
 namespace std _GLIBCXX_VISIBILITY(default)
@@ -494,6 +495,33 @@ namespace ranges
       operator()(_Iter __first, _Sent __last,
                  const _Tp& __value, _Proj __proj = {}) const
       {
+        // Use memchr for contiguous ranges of bytes. It cannot be used
+        // during constant evaluation, nor to read volatile memory.
+        if constexpr (is_same_v<_Proj, identity>)
+          if constexpr(__can_use_memchr_for_find<iter_value_t<_Iter>, _Tp>)
+            if constexpr (sized_sentinel_for<_Sent, _Iter>)
+              if constexpr (contiguous_iterator<_Iter>
+                  && !is_volatile_v<remove_reference_t<iter_reference_t<_Iter>>>)
+                if (!std::is_constant_evaluated())
+                  {
+                    using _ValT = iter_value_t<_Iter>;
+                    auto __n = __last - __first;
+                    // If converting the value to the 1-byte value type
+                    // alters it, then no element can compare equal to it,
+                    // but memchr would still match its low byte, e.g.
+                    // memchr("a", 'a'+256, 1) would give a false positive.
+                    if (static_cast<_ValT>(__value) == __value)
+                      if (__n > 0)
+                        {
+                          const size_t __nu = static_cast<size_t>(__n);
+                          const int __ival = static_cast<int>(__value);
+                          const void* __p0 = std::to_address(__first);
+                          if (auto __p1 = __builtin_memchr(__p0, __ival, __nu))
+                            __n = (const char*)__p1 - (const char*)__p0;
+                        }
+                    return __first + __n;
+                  }
+
         while (__first != __last
                && !(std::__invoke(__proj, *__first) == __value))
           ++__first;
diff --git a/libstdc++-v3/include/bits/stl_algo.h b/libstdc++-v3/include/bits/stl_algo.h
index 1a996aa61da..eba3157a480 100644
--- a/libstdc++-v3/include/bits/stl_algo.h
+++ b/libstdc++-v3/include/bits/stl_algo.h
@@ -3836,6 +3836,7 @@ _GLIBCXX_BEGIN_NAMESPACE_ALGO
    *  such that @c *i == @p __val, or @p __last if no such iterator exists.
   */
   template<typename _InputIterator, typename _Tp>
+    _GLIBCXX_NODISCARD
     _GLIBCXX20_CONSTEXPR
     inline _InputIterator
     find(_InputIterator __first, _InputIterator __last,
@@ -3846,6 +3847,52 @@ _GLIBCXX_BEGIN_NAMESPACE_ALGO
       __glibcxx_function_requires(_EqualOpConcept<
                 typename iterator_traits<_InputIterator>::value_type, _Tp>)
       __glibcxx_requires_valid_range(__first, __last);
+
+#if __cpp_if_constexpr && __glibcxx_type_trait_variable_templates
+      using _ValT = typename iterator_traits<_InputIterator>::value_type;
+      if constexpr (__can_use_memchr_for_find<_ValT, _Tp>)
+        {
+          // If converting the value to the 1-byte value_type alters its value,
+          // then it would not be found by std::find using equality comparison.
+          // We need to check this here, because otherwise something like
+          // memchr("a", 'a'+256, 1) would give a false positive match.
+          if (!(static_cast<_ValT>(__val) == __val))
+            return __last;
+
+          // memchr can only be used if the range is contiguous (which also
+          // makes __first + __n valid) and the elements are not volatile.
+          using _Ref = typename iterator_traits<_InputIterator>::reference;
+          constexpr bool __contiguous
+            = is_pointer_v<decltype(std::__niter_base(__first))>
+#if __cpp_lib_concepts
+                || contiguous_iterator<_InputIterator>
+#endif
+                ;
+          if constexpr (__contiguous && !is_volatile_v<remove_reference_t<_Ref>>)
+            // memchr cannot be used during constant evaluation.
+            if (!std::__is_constant_evaluated())
+              {
+                const void* __p0 = nullptr;
+                if constexpr (is_pointer_v<decltype(std::__niter_base(__first))>)
+                  __p0 = std::__niter_base(__first);
+#if __cpp_lib_concepts
+                else
+                  __p0 = std::to_address(__first);
+#endif
+
+                auto __n = std::distance(__first, __last);
+                if (__n > 0)
+                  {
+                    const size_t __nu = static_cast<size_t>(__n);
+                    const int __ival = static_cast<int>(__val);
+                    if (auto __p1 = __builtin_memchr(__p0, __ival, __nu))
+                      return __first + ((const char*)__p1 - (const char*)__p0);
+                  }
+                return __last;
+              }
+        }
+#endif
+
       return std::__find_if(__first, __last,
                             __gnu_cxx::__ops::__iter_equals_val(__val));
     }
diff --git a/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
new file mode 100644
index 00000000000..ac189dac65f
--- /dev/null
+++ b/libstdc++-v3/testsuite/25_algorithms/find/bytes.cc
@@ -0,0 +1,160 @@
+// { dg-do run }
+
+#include <algorithm>
+#include <cstddef> // std::byte
+#include <list>
+#include <testsuite_hooks.h>
+
+// PR libstdc++/88545 made std::find use memchr as an optimization.
+// This test verifies that it didn't change any semantics.
+
+template<typename C>
+void
+test_char()
+{
+  const C a[] = { (C)'a', (C)'b', (C)'c', (C)'d' };
+  const C* end = a + sizeof(a);
+  const C* res = std::find(a, end, a[0]);
+  VERIFY( res == a );
+  res = std::find(a, end, a[2]);
+  VERIFY( res == a+2 );
+  res = std::find(a, end, a[0] + 256);
+  VERIFY( res == end );
+  res = std::find(a, end, a[0] - 256);
+  VERIFY( res == end );
+  res = std::find(a, end, 256);
+  VERIFY( res == end );
+
+#ifdef __cpp_lib_ranges
+  res = std::ranges::find(a, a[0]);
+  VERIFY( res == a );
+  res = std::ranges::find(a, a[2]);
+  VERIFY( res == a+2 );
+  res = std::ranges::find(a, a[0] + 256);
+  VERIFY( res == end );
+  res = std::ranges::find(a, a[0] - 256);
+  VERIFY( res == end );
+  res = std::ranges::find(a, 256);
+  VERIFY( res == end );
+#endif
+}
+
+// Trivial type of size 1, with custom equality.
+struct S {
+  bool operator==(const S&) const { return true; };
+  char c;
+};
+
+// Enumeration type of size 1 (since C++11), with custom equality.
+enum E
+#if __cplusplus >= 201103L
+: unsigned char
+#endif
+{ e1 = 1, e255 = 255 };
+
+bool operator==(E l, E r) { return (l % 3) == (r % 3); }
+
+struct X { char c; };
+bool operator==(X, char) { return false; }
+bool operator==(char, X) { return false; }
+
+void
+test_non_characters()
+{
+  S s[3] = { {'a'}, {'b'}, {'c'} };
+  S sx = {'x'};
+  S* sres = std::find(s, s+3, sx);
+  VERIFY( sres == s ); // memchr optimization would not find a match
+
+  E e[3] = { E(1), E(2), E(3) };
+  E* eres = std::find(e, e+3, E(4));
+  VERIFY( eres == e ); // memchr optimization would not find a match
+
+  char x[1] = { 'x' };
+  X xx = { 'x' };
+  char* xres = std::find(x, x+1, xx);
+  VERIFY( xres == x+1 ); // memchr optimization would find a match
+
+#ifdef __cpp_lib_byte
+  std::byte b[] = { std::byte{0}, std::byte{1}, std::byte{2}, std::byte{3} };
+  std::byte* bres = std::find(b, b+4, std::byte{4});
+  VERIFY( bres == b+4 );
+  bres = std::find(b, b+2, std::byte{3});
+  VERIFY( bres == b+2 );
+  bres = std::find(b, b+3, std::byte{3});
+  VERIFY( bres == b+3 );
+#endif
+
+#ifdef __cpp_lib_ranges
+  sres = std::ranges::find(s, sx);
+  VERIFY( sres == s );
+
+  eres = std::ranges::find(e, e+3, E(4));
+  VERIFY( eres == e );
+
+  // X does not model equality_comparable_with<char>, so std::ranges::find
+  // cannot be used with it, and xx is only tested with std::find above.
+
+  bres = std::ranges::find(b, std::byte{4});
+  VERIFY( bres == b+4 );
+  bres = std::ranges::find(b, b+2, std::byte{3});
+  VERIFY( bres == b+2 );
+  bres = std::ranges::find(b, b+3, std::byte{3});
+  VERIFY( bres == b+3 );
+#endif
+}
+
+// Iterators that memchr cannot be used with must still work.
+void
+test_unsuitable_iterators()
+{
+  const char a[] = { 'a', 'b', 'c', 'd' };
+
+  // Not contiguous, so std::find must not try to use memchr.
+  std::list<char> l(a, a+4);
+  std::list<char>::iterator lres = std::find(l.begin(), l.end(), 'c');
+  VERIFY( lres != l.end() && *lres == 'c' );
+  lres = std::find(l.begin(), l.end(), 'c'+256);
+  VERIFY( lres == l.end() );
+
+  // memchr cannot be used to read volatile memory.
+  volatile char v[] = { 'a', 'b', 'c', 'd' };
+  volatile char* vres = std::find(v, v+4, 'c');
+  VERIFY( vres == v+2 );
+  vres = std::find(v, v+4, 'c'+256);
+  VERIFY( vres == v+4 );
+#ifdef __cpp_lib_ranges
+  vres = std::ranges::find(v, 'c');
+  VERIFY( vres == v+2 );
+#endif
+}
+
+#if __cpp_lib_constexpr_algorithms
+// memchr cannot be used during constant evaluation.
+constexpr bool
+test_constexpr()
+{
+  char c[] = { 'a', 'b', 'c', 'd' };
+  if (std::find(c, c+4, 'c') != c+2)
+    return false;
+  if (std::find(c, c+4, 'c'+256) != c+4)
+    return false;
+#ifdef __cpp_lib_ranges
+  if (std::ranges::find(c, 'c') != c+2)
+    return false;
+  if (std::ranges::find(c, 'c'+256) != c+4)
+    return false;
+#endif
+  return true;
+}
+static_assert( test_constexpr() );
+#endif
+
+int main()
+{
+  test_char<char>();
+  test_char<signed char>();
+  test_char<unsigned char>();
+  test_non_characters();
+  test_unsuitable_iterators();
+}