diff mbox series

[committed,v2] libstdc++: Fix std::codecvt<wchar_t, char, mbstate_t> for empty dest [PR37475]

Message ID 20240627110832.119192-1-jwakely@redhat.com
State New
Headers show
Series [committed,v2] libstdc++: Fix std::codecvt<wchar_t, char, mbstate_t> for empty dest [PR37475] | expand

Commit Message

Jonathan Wakely June 27, 2024, 11:07 a.m. UTC
Here's what I've pushed, with a typo fixed as spotted by Kristian in the
PR comments.

Tested x86_64-linux. Pushed to trunk.

-- >8 --

For the GNU locale model, codecvt::do_out and codecvt::do_in incorrectly
return 'ok' when the destination range is empty. That happens because
detecting incomplete output is done in the loop body, and the loop is
never even entered if to == to_end.

By restructuring the loop condition so that we check the output range
separately, we can ensure that for a non-empty source range, we always
enter the loop at least once, and detect if the destination range is too
small.

The loops also seem easier to reason about if we return immediately on
any error, instead of checking the result twice on every iteration. We
can use an RAII type to restore the locale before returning, which also
simplifies all the other member functions.

libstdc++-v3/ChangeLog:

	PR libstdc++/37475
	* config/locale/gnu/codecvt_members.cc (Guard): New RAII type.
	(do_out, do_in): Return partial if the destination is empty but
	the source is not. Use Guard to restore locale on scope exit.
	Return immediately on any conversion error.
	(do_encoding, do_max_length, do_length): Use Guard.
	* testsuite/22_locale/codecvt/in/char/37475.cc: New test.
	* testsuite/22_locale/codecvt/in/wchar_t/37475.cc: New test.
	* testsuite/22_locale/codecvt/out/char/37475.cc: New test.
	* testsuite/22_locale/codecvt/out/wchar_t/37475.cc: New test.
---
 .../config/locale/gnu/codecvt_members.cc      | 117 ++++++++----------
 .../22_locale/codecvt/in/char/37475.cc        |  23 ++++
 .../22_locale/codecvt/in/wchar_t/37475.cc     |  23 ++++
 .../22_locale/codecvt/out/char/37475.cc       |  23 ++++
 .../22_locale/codecvt/out/wchar_t/37475.cc    |  23 ++++
 5 files changed, 142 insertions(+), 67 deletions(-)
 create mode 100644 libstdc++-v3/testsuite/22_locale/codecvt/in/char/37475.cc
 create mode 100644 libstdc++-v3/testsuite/22_locale/codecvt/in/wchar_t/37475.cc
 create mode 100644 libstdc++-v3/testsuite/22_locale/codecvt/out/char/37475.cc
 create mode 100644 libstdc++-v3/testsuite/22_locale/codecvt/out/wchar_t/37475.cc
diff mbox series

Patch

diff --git a/libstdc++-v3/config/locale/gnu/codecvt_members.cc b/libstdc++-v3/config/locale/gnu/codecvt_members.cc
index 034713d236e..794f25a5f35 100644
--- a/libstdc++-v3/config/locale/gnu/codecvt_members.cc
+++ b/libstdc++-v3/config/locale/gnu/codecvt_members.cc
@@ -37,8 +37,23 @@  namespace std _GLIBCXX_VISIBILITY(default)
 {
 _GLIBCXX_BEGIN_NAMESPACE_VERSION
 
-  // Specializations.
 #ifdef _GLIBCXX_USE_WCHAR_T
+namespace
+{
+  // RAII type for changing and restoring the current thread's locale.
+  struct Guard
+  {
+#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
+    explicit Guard(__c_locale loc) : old(__uselocale(loc)) { }
+    ~Guard() { __uselocale(old); }
+#else
+    explicit Guard(__c_locale) { }
+#endif
+    __c_locale old;
+  };
+}
+
+  // Specializations.
   codecvt_base::result
   codecvt<wchar_t, char, mbstate_t>::
   do_out(state_type& __state, const intern_type* __from,
@@ -46,22 +61,21 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	 extern_type* __to, extern_type* __to_end,
 	 extern_type*& __to_next) const
   {
-    result __ret = ok;
     state_type __tmp_state(__state);
-
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __c_locale __old = __uselocale(_M_c_locale_codecvt);
-#endif
+    Guard g(_M_c_locale_codecvt);
 
     // wcsnrtombs is *very* fast but stops if encounters NUL characters:
     // in case we fall back to wcrtomb and then continue, in a loop.
     // NB: wcsnrtombs is a GNU extension
-    for (__from_next = __from, __to_next = __to;
-	 __from_next < __from_end && __to_next < __to_end
-	 && __ret == ok;)
+    __from_next = __from;
+    __to_next = __to;
+    while (__from_next < __from_end)
       {
-	const intern_type* __from_chunk_end = wmemchr(__from_next, L'\0',
-						      __from_end - __from_next);
+	if (__to_next >= __to_end)
+	  return partial;
+
+	const intern_type* __from_chunk_end
+	  = wmemchr(__from_next, L'\0', __from_end - __from_next);
 	if (!__from_chunk_end)
 	  __from_chunk_end = __from_end;
 
@@ -77,12 +91,12 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	    for (; __from < __from_next; ++__from)
 	      __to_next += wcrtomb(__to_next, *__from, &__tmp_state);
 	    __state = __tmp_state;
-	    __ret = error;
+	    return error;
 	  }
 	else if (__from_next && __from_next < __from_chunk_end)
 	  {
 	    __to_next += __conv;
-	    __ret = partial;
+	    return partial;
 	  }
 	else
 	  {
@@ -90,13 +104,13 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	    __to_next += __conv;
 	  }
 
-	if (__from_next < __from_end && __ret == ok)
+	if (__from_next < __from_end)
 	  {
 	    extern_type __buf[MB_LEN_MAX];
 	    __tmp_state = __state;
 	    const size_t __conv2 = wcrtomb(__buf, *__from_next, &__tmp_state);
 	    if (__conv2 > static_cast<size_t>(__to_end - __to_next))
-	      __ret = partial;
+	      return partial;
 	    else
 	      {
 		memcpy(__to_next, __buf, __conv2);
@@ -107,11 +121,7 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	  }
       }
 
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __uselocale(__old);
-#endif
-
-    return __ret;
+    return ok;
   }
 
   codecvt_base::result
@@ -121,24 +131,22 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	intern_type* __to, intern_type* __to_end,
 	intern_type*& __to_next) const
   {
-    result __ret = ok;
     state_type __tmp_state(__state);
-
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __c_locale __old = __uselocale(_M_c_locale_codecvt);
-#endif
+    Guard g(_M_c_locale_codecvt);
 
     // mbsnrtowcs is *very* fast but stops if encounters NUL characters:
     // in case we store a L'\0' and then continue, in a loop.
     // NB: mbsnrtowcs is a GNU extension
-    for (__from_next = __from, __to_next = __to;
-	 __from_next < __from_end && __to_next < __to_end
-	 && __ret == ok;)
+    __from_next = __from;
+    __to_next = __to;
+    while (__from_next < __from_end)
       {
-	const extern_type* __from_chunk_end;
-	__from_chunk_end = static_cast<const extern_type*>(memchr(__from_next, '\0',
-								  __from_end
-								  - __from_next));
+	if (__to_next >= __to_end)
+	  return partial;
+
+	const extern_type* __from_chunk_end
+	  = static_cast<const extern_type*>(memchr(__from_next, '\0',
+						   __from_end - __from_next));
 	if (!__from_chunk_end)
 	  __from_chunk_end = __from_end;
 
@@ -161,13 +169,13 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	      }
 	    __from_next = __from;
 	    __state = __tmp_state;
-	    __ret = error;
+	    return error;
 	  }
 	else if (__from_next && __from_next < __from_chunk_end)
 	  {
 	    // It is unclear what to return in this case (see DR 382).
 	    __to_next += __conv;
-	    __ret = partial;
+	    return partial;
 	  }
 	else
 	  {
@@ -175,7 +183,7 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	    __to_next += __conv;
 	  }
 
-	if (__from_next < __from_end && __ret == ok)
+	if (__from_next < __from_end)
 	  {
 	    if (__to_next < __to_end)
 	      {
@@ -185,48 +193,30 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 		*__to_next++ = L'\0';
 	      }
 	    else
-	      __ret = partial;
+	      return partial;
 	  }
       }
 
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __uselocale(__old);
-#endif
-
-    return __ret;
+    return ok;
   }
 
   int
   codecvt<wchar_t, char, mbstate_t>::
   do_encoding() const throw()
   {
+    Guard g(_M_c_locale_codecvt);
     // XXX This implementation assumes that the encoding is
     // stateless and is either single-byte or variable-width.
-    int __ret = 0;
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __c_locale __old = __uselocale(_M_c_locale_codecvt);
-#endif
-    if (MB_CUR_MAX == 1)
-      __ret = 1;
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __uselocale(__old);
-#endif
-    return __ret;
+    return MB_CUR_MAX == 1;
   }
 
   int
   codecvt<wchar_t, char, mbstate_t>::
   do_max_length() const throw()
   {
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __c_locale __old = __uselocale(_M_c_locale_codecvt);
-#endif
+    Guard g(_M_c_locale_codecvt);
     // XXX Probably wrong for stateful encodings.
-    int __ret = MB_CUR_MAX;
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __uselocale(__old);
-#endif
-    return __ret;
+    return MB_CUR_MAX;
   }
 
   int
@@ -236,10 +226,7 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
   {
     int __ret = 0;
     state_type __tmp_state(__state);
-
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __c_locale __old = __uselocale(_M_c_locale_codecvt);
-#endif
+    Guard g(_M_c_locale_codecvt);
 
     // mbsnrtowcs is *very* fast but stops if encounters NUL characters:
     // in case we advance past it and then continue, in a loop.
@@ -295,10 +282,6 @@  _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	  }
       }
 
-#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 2)
-    __uselocale(__old);
-#endif
-
     return __ret;
   }
 #endif
diff --git a/libstdc++-v3/testsuite/22_locale/codecvt/in/char/37475.cc b/libstdc++-v3/testsuite/22_locale/codecvt/in/char/37475.cc
new file mode 100644
index 00000000000..6184c3280cb
--- /dev/null
+++ b/libstdc++-v3/testsuite/22_locale/codecvt/in/char/37475.cc
@@ -0,0 +1,23 @@ 
+#include <locale>
+#include <testsuite_hooks.h>
+
+void
+test_pr37475()
+{
+  typedef std::codecvt<char, char, std::mbstate_t> test_type;
+  const test_type& cvt = std::use_facet<test_type>(std::locale::classic());
+  const char from = 'a';
+  const char* from_next;
+  char to = 0;
+  char* to_next;
+  std::mbstate_t st = std::mbstate_t();
+  std::codecvt_base::result res
+    = cvt.in(st, &from, &from+1, from_next, &to, &to, to_next);
+
+  VERIFY( res == std::codecvt_base::noconv );
+}
+
+int main()
+{
+  test_pr37475();
+}
diff --git a/libstdc++-v3/testsuite/22_locale/codecvt/in/wchar_t/37475.cc b/libstdc++-v3/testsuite/22_locale/codecvt/in/wchar_t/37475.cc
new file mode 100644
index 00000000000..a0e64847ea9
--- /dev/null
+++ b/libstdc++-v3/testsuite/22_locale/codecvt/in/wchar_t/37475.cc
@@ -0,0 +1,23 @@ 
+#include <locale>
+#include <testsuite_hooks.h>
+
+void
+test_pr37475()
+{
+  typedef std::codecvt<wchar_t, char, std::mbstate_t> test_type;
+  const test_type& cvt = std::use_facet<test_type>(std::locale::classic());
+  const char from = 'a';
+  const char* from_next;
+  wchar_t to = 0;
+  wchar_t* to_next;
+  std::mbstate_t st = std::mbstate_t();
+  std::codecvt_base::result res
+    = cvt.in(st, &from, &from+1, from_next, &to, &to, to_next);
+
+  VERIFY( res == std::codecvt_base::partial );
+}
+
+int main()
+{
+  test_pr37475();
+}
diff --git a/libstdc++-v3/testsuite/22_locale/codecvt/out/char/37475.cc b/libstdc++-v3/testsuite/22_locale/codecvt/out/char/37475.cc
new file mode 100644
index 00000000000..8736e4b7f3f
--- /dev/null
+++ b/libstdc++-v3/testsuite/22_locale/codecvt/out/char/37475.cc
@@ -0,0 +1,23 @@ 
+#include <locale>
+#include <assert.h>
+
+void
+test_pr37475()
+{
+  typedef std::codecvt<char, char, std::mbstate_t> test_type;
+  const test_type& cvt = std::use_facet<test_type>(std::locale::classic());
+  const char from = 'a';
+  const char* from_next;
+  char to;
+  char* to_next;
+  std::mbstate_t st = std::mbstate_t();
+  std::codecvt_base::result res
+    = cvt.out(st, &from, &from+1, from_next, &to, &to, to_next);
+
+  assert( res == std::codecvt_base::noconv );
+}
+
+int main()
+{
+  test_pr37475();
+}
diff --git a/libstdc++-v3/testsuite/22_locale/codecvt/out/wchar_t/37475.cc b/libstdc++-v3/testsuite/22_locale/codecvt/out/wchar_t/37475.cc
new file mode 100644
index 00000000000..2cd2edb7404
--- /dev/null
+++ b/libstdc++-v3/testsuite/22_locale/codecvt/out/wchar_t/37475.cc
@@ -0,0 +1,23 @@ 
+#include <locale>
+#include <assert.h>
+
+void
+test_pr37475()
+{
+  typedef std::codecvt<wchar_t, char, std::mbstate_t> test_type;
+  const test_type& cvt = std::use_facet<test_type>(std::locale::classic());
+  const wchar_t from = L'a';
+  const wchar_t* from_next;
+  char to;
+  char* to_next;
+  std::mbstate_t st = std::mbstate_t();
+  std::codecvt_base::result res
+    = cvt.out(st, &from, &from+1, from_next, &to, &to, to_next);
+
+  assert( res == std::codecvt_base::partial );
+}
+
+int main()
+{
+  test_pr37475();
+}