From patchwork Fri Aug 9 07:50:12 2019 Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit X-Patchwork-Submitter: Alexandre Oliva X-Patchwork-Id: 1144481 Return-Path: X-Original-To: incoming@patchwork.ozlabs.org Delivered-To: patchwork-incoming@bilbo.ozlabs.org Authentication-Results: ozlabs.org; spf=pass (mailfrom) smtp.mailfrom=gcc.gnu.org (client-ip=209.132.180.131; helo=sourceware.org; envelope-from=gcc-patches-return-506561-incoming=patchwork.ozlabs.org@gcc.gnu.org; receiver=) Authentication-Results: ozlabs.org; dmarc=none (p=none dis=none) header.from=adacore.com Authentication-Results: ozlabs.org; dkim=pass (1024-bit key; unprotected) header.d=gcc.gnu.org header.i=@gcc.gnu.org header.b="at11QYdJ"; dkim-atps=neutral Received: from sourceware.org (server1.sourceware.org [209.132.180.131]) (using TLSv1.2 with cipher ECDHE-RSA-AES256-GCM-SHA384 (256/256 bits)) (No client certificate requested) by ozlabs.org (Postfix) with ESMTPS id 464cqT5xxvz9sNF for ; Fri, 9 Aug 2019 17:50:45 +1000 (AEST) DomainKey-Signature: a=rsa-sha1; c=nofws; d=gcc.gnu.org; h=list-id :list-unsubscribe:list-archive:list-post:list-help:sender:from :to:cc:subject:date:message-id:mime-version:content-type :content-transfer-encoding; q=dns; s=default; b=oWlJcEkguco4eXdE aazj5ee8DYq8REzVjCqu1Ihp71SkYVl4b90zmwmV3Dzf8McTKJNMYKzAeSSQ2kaQ aPxh5iccOr06P996vQBK8+yPkylt537/Xrk7xpTKl5+XGRqTkFUl0lEyMIVQQgHi Ng7++aUZxA4BPTdYriHquGXoCa8= DKIM-Signature: v=1; a=rsa-sha1; c=relaxed; d=gcc.gnu.org; h=list-id :list-unsubscribe:list-archive:list-post:list-help:sender:from :to:cc:subject:date:message-id:mime-version:content-type :content-transfer-encoding; s=default; bh=UD9AjQs3SS1NX5qv35aexj Iehdk=; b=at11QYdJsPx9Tzzh9z2nmDZtYHCpV4Ulmy52XhNlQhRBh9cDUy9+ku Lexo3tuTZpH0FEAFbJMDwM2K10VRU/39pD7z/HdLWdcx7SQ35cFYgDe7C8Cw16Dj JMmiKnBCSS0MyixzbQUb4Oz73I04mX6j1EcQK81KYQkA35xS4kSmc= Received: (qmail 22964 invoked by alias); 9 Aug 2019 07:50:29 -0000 Mailing-List: contact gcc-patches-help@gcc.gnu.org; run by ezmlm Precedence: bulk List-Id: List-Unsubscribe: List-Archive: List-Post: List-Help: Sender: gcc-patches-owner@gcc.gnu.org Delivered-To: mailing list gcc-patches@gcc.gnu.org Received: (qmail 22949 invoked by uid 89); 9 Aug 2019 07:50:29 -0000 Authentication-Results: sourceware.org; auth=none X-Spam-SWARE-Status: No, score=-26.9 required=5.0 tests=BAYES_00, GIT_PATCH_0, GIT_PATCH_1, GIT_PATCH_2, GIT_PATCH_3, RCVD_IN_DNSWL_NONE, SPF_PASS autolearn=ham version=3.3.1 spammy=HTo:U*drepper, que X-HELO: rock.gnat.com Received: from rock.gnat.com (HELO rock.gnat.com) (205.232.38.15) by sourceware.org (qpsmtpd/0.93/v0.84-503-g423c35a) with ESMTP; Fri, 09 Aug 2019 07:50:27 +0000 Received: from localhost (localhost.localdomain [127.0.0.1]) by filtered-rock.gnat.com (Postfix) with ESMTP id D540B56063; Fri, 9 Aug 2019 03:50:25 -0400 (EDT) Received: from rock.gnat.com ([127.0.0.1]) by localhost (rock.gnat.com [127.0.0.1]) (amavisd-new, port 10024) with LMTP id WmIOhw6SOBHz; Fri, 9 Aug 2019 03:50:25 -0400 (EDT) Received: from free.home (tron.gnat.com [IPv6:2620:20:4000:0:46a8:42ff:fe0e:e294]) (using TLSv1.2 with cipher ECDHE-RSA-AES256-GCM-SHA384 (256/256 bits)) (No client certificate requested) by rock.gnat.com (Postfix) with ESMTPS id 3D4C756061; Fri, 9 Aug 2019 03:50:24 -0400 (EDT) Received: from livre (livre.home [172.31.160.2]) by free.home (8.15.2/8.15.2) with ESMTPS id x797oCZU1413857 (version=TLSv1.3 cipher=TLS_AES_256_GCM_SHA384 bits=256 verify=NOT); Fri, 9 Aug 2019 04:50:12 -0300 From: Alexandre Oliva To: libstdc++@gcc.gnu.org, Ulrich Drepper Cc: gcc-patches@gcc.gnu.org, Corentin Gay Subject: skip Cholesky decomposition in is>>n_mv_dist Date: Fri, 09 Aug 2019 04:50:12 -0300 Message-ID: User-Agent: Gnus/5.13 (Gnus v5.13) Emacs/26.1 (gnu/linux) MIME-Version: 1.0 normal_mv_distribution maintains the variance-covariance matrix param in Cholesky-decomposed form. Existing param_type constructors, when taking a full or lower-triangle varcov matrix, perform Cholesky decomposition to convert it to the internal representation. This internal representation is visible both in the varcov() result, and in the streamed-out representation of a normal_mv_distribution object. The problem is that when that representation is streamed back in, the read-back decomposed varcov matrix is used as a lower-triangle non-decomposed varcov matrix, and it undergoes Cholesky decomposition again. So, each cycle of stream-out/stream-in changes the varcov matrix to its "square root", instead of restoring the original params. This patch includes Corentin's changes that introduce verification in testsuite/ext/random/normal_mv_distribution/operators/serialize.cc and other similar tests that the object read back in compares equal to the written-out object: the modified tests pass only if (u == v). This patch also fixes the error exposed by his change, introducing an alternate private constructor for param_type, used only by operator>>. Tested on x86_64-linux-gnu. Ok to install? for libstdc++-v3/ChangeLog * include/ext/random (normal_mv_distribution::param_type::param_type): New private ctor taking a decomposed varcov matrix, for use by... (operator>>): ... this, befriended. * include/ext/random.tcc (operator>>): Use it. (normal_mv_distribution::param_type::_M_init_lower): Adjust member function name in exception message. for libstdc++-v3/ChangeLog from Corentin Gay * testsuite/ext/random/beta_distribution/operators/serialize.cc, testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc, testsuite/ext/random/normal_mv_distribution/operators/serialize.cc, testsuite/ext/random/triangular_distribution/operators/serialize.cc, testsuite/ext/random/von_mises_distribution/operators/serialize.cc: Add call to `VERIFY`. --- libstdc++-v3/include/ext/random | 15 +++++++++++++++ libstdc++-v3/include/ext/random.tcc | 8 +++++--- .../beta_distribution/operators/serialize.cc | 2 ++ .../operators/serialize.cc | 1 + .../normal_mv_distribution/operators/serialize.cc | 2 ++ .../triangular_distribution/operators/serialize.cc | 2 ++ .../von_mises_distribution/operators/serialize.cc | 2 ++ 7 files changed, 29 insertions(+), 3 deletions(-) diff --git a/libstdc++-v3/include/ext/random b/libstdc++-v3/include/ext/random index 41a2962c8f6e5..d5574e02ba02c 100644 --- a/libstdc++-v3/include/ext/random +++ b/libstdc++-v3/include/ext/random @@ -752,6 +752,21 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION _InputIterator2 __varbegin, _InputIterator2 __varend); + // param_type constructors apply Cholesky decomposition to the + // varcov matrix in _M_init_full and _M_init_lower, but the + // varcov matrix output ot a stream is already decomposed, so + // we need means to restore it as-is when reading it back in. + template + friend std::basic_istream<_CharT, _Traits>& + operator>>(std::basic_istream<_CharT, _Traits>& __is, + __gnu_cxx::normal_mv_distribution<_Dimen1, _RealType1>& + __x); + param_type(std::array<_RealType, _Dimen> const &__mean, + std::array<_RealType, _M_t_size> const &__varcov) + : _M_mean (__mean), _M_t (__varcov) + {} + std::array<_RealType, _Dimen> _M_mean; std::array<_RealType, _M_t_size> _M_t; }; diff --git a/libstdc++-v3/include/ext/random.tcc b/libstdc++-v3/include/ext/random.tcc index 31dc33a2555ed..a8a49a3a9fa6a 100644 --- a/libstdc++-v3/include/ext/random.tcc +++ b/libstdc++-v3/include/ext/random.tcc @@ -581,7 +581,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION __sum = *__varcovbegin++ - __sum; if (__builtin_expect(__sum <= _RealType(0), 0)) std::__throw_runtime_error(__N("normal_mv_distribution::" - "param_type::_M_init_full")); + "param_type::_M_init_lower")); *__w++ = std::sqrt(__sum); } } @@ -709,9 +709,11 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION __is >> __x._M_nd; + // The param_type temporary is built with a private constructor, + // to skip the Cholesky decomposition that would be performed + // otherwise. __x.param(typename normal_mv_distribution<_Dimen, _RealType>:: - param_type(__mean.begin(), __mean.end(), - __varcov.begin(), __varcov.end())); + param_type(__mean, __varcov)); __is.flags(__flags); return __is; diff --git a/libstdc++-v3/testsuite/ext/random/beta_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/beta_distribution/operators/serialize.cc index b05417156d191..a4925fc1c41be 100644 --- a/libstdc++-v3/testsuite/ext/random/beta_distribution/operators/serialize.cc +++ b/libstdc++-v3/testsuite/ext/random/beta_distribution/operators/serialize.cc @@ -23,6 +23,7 @@ #include #include +#include void test01() @@ -35,6 +36,7 @@ test01() str << u; str >> v; + VERIFY( u == v ); } int main() diff --git a/libstdc++-v3/testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc index 9c2cc46ac1ce0..e9077b2c58d65 100644 --- a/libstdc++-v3/testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc +++ b/libstdc++-v3/testsuite/ext/random/hypergeometric_distribution/operators/serialize.cc @@ -38,6 +38,7 @@ test01() str << u; str >> v; + VERIFY( u == v ); } int diff --git a/libstdc++-v3/testsuite/ext/random/normal_mv_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/normal_mv_distribution/operators/serialize.cc index 8d83f9e6966d2..f5fbc42a686f0 100644 --- a/libstdc++-v3/testsuite/ext/random/normal_mv_distribution/operators/serialize.cc +++ b/libstdc++-v3/testsuite/ext/random/normal_mv_distribution/operators/serialize.cc @@ -23,6 +23,7 @@ #include #include +#include void test01() @@ -35,6 +36,7 @@ test01() str << u; str >> v; + VERIFY( u == v ); } int main() diff --git a/libstdc++-v3/testsuite/ext/random/triangular_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/triangular_distribution/operators/serialize.cc index cf17fea8b03ff..75e16cf0437cb 100644 --- a/libstdc++-v3/testsuite/ext/random/triangular_distribution/operators/serialize.cc +++ b/libstdc++-v3/testsuite/ext/random/triangular_distribution/operators/serialize.cc @@ -23,6 +23,7 @@ #include #include +#include void test01() @@ -35,6 +36,7 @@ test01() str << u; str >> v; + VERIFY( u == v ); } int main() diff --git a/libstdc++-v3/testsuite/ext/random/von_mises_distribution/operators/serialize.cc b/libstdc++-v3/testsuite/ext/random/von_mises_distribution/operators/serialize.cc index f3d7912e314ba..b32a31dee6421 100644 --- a/libstdc++-v3/testsuite/ext/random/von_mises_distribution/operators/serialize.cc +++ b/libstdc++-v3/testsuite/ext/random/von_mises_distribution/operators/serialize.cc @@ -23,6 +23,7 @@ #include #include +#include void test01() @@ -35,6 +36,7 @@ test01() str << u; str >> v; + VERIFY( u == v ); } int main()