From patchwork Thu Aug 1 14:56:49 2024 Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit X-Patchwork-Submitter: Arthur Cohen X-Patchwork-Id: 1967786 Return-Path: X-Original-To: incoming@patchwork.ozlabs.org Delivered-To: patchwork-incoming@legolas.ozlabs.org Authentication-Results: legolas.ozlabs.org; dkim=pass (2048-bit key; unprotected) header.d=embecosm.com header.i=@embecosm.com header.a=rsa-sha256 header.s=google header.b=VO2lRT81; dkim-atps=neutral Authentication-Results: legolas.ozlabs.org; spf=pass (sender SPF authorized) smtp.mailfrom=gcc.gnu.org (client-ip=2620:52:3:1:0:246e:9693:128c; helo=server2.sourceware.org; envelope-from=gcc-patches-bounces~incoming=patchwork.ozlabs.org@gcc.gnu.org; receiver=patchwork.ozlabs.org) Received: from server2.sourceware.org (server2.sourceware.org [IPv6:2620:52:3:1:0:246e:9693:128c]) (using TLSv1.3 with cipher TLS_AES_256_GCM_SHA384 (256/256 bits) key-exchange X25519 server-signature ECDSA (secp384r1) server-digest SHA384) (No client certificate requested) by legolas.ozlabs.org (Postfix) with ESMTPS id 4WZXm81VQZz1ybX for ; Fri, 2 Aug 2024 01:23:20 +1000 (AEST) Received: from server2.sourceware.org (localhost [IPv6:::1]) by sourceware.org (Postfix) with ESMTP id 777793860769 for ; Thu, 1 Aug 2024 15:23:18 +0000 (GMT) X-Original-To: gcc-patches@gcc.gnu.org Delivered-To: gcc-patches@gcc.gnu.org Received: from mail-ed1-x52a.google.com (mail-ed1-x52a.google.com [IPv6:2a00:1450:4864:20::52a]) by sourceware.org (Postfix) with ESMTPS id 231573860C2E for ; Thu, 1 Aug 2024 14:59:15 +0000 (GMT) DMARC-Filter: OpenDMARC Filter v1.4.2 sourceware.org 231573860C2E Authentication-Results: sourceware.org; dmarc=none (p=none dis=none) header.from=embecosm.com Authentication-Results: sourceware.org; spf=pass smtp.mailfrom=embecosm.com ARC-Filter: OpenARC Filter v1.0.0 sourceware.org 231573860C2E Authentication-Results: server2.sourceware.org; arc=none smtp.remote-ip=2a00:1450:4864:20::52a ARC-Seal: i=1; a=rsa-sha256; d=sourceware.org; s=key; t=1722524395; cv=none; b=RkLFbY7YY69jVyrv3vsSkDb2IinXZs5fAce10jvkbdHvS69wm6l3Xiug1XBRLAt12coQ1caXZF+Orwtiz5pdLhEuuItJlRM1gas+cIHCS+p0hCZ/stq24MF3te6L6OmT63KNd0i0soL0QxeHqxRpM6Q/plNGLbxDn+KNiREpurc= ARC-Message-Signature: i=1; a=rsa-sha256; d=sourceware.org; s=key; t=1722524395; c=relaxed/simple; bh=rNlLxh5Wa+jiFBXdcZ6XX5tyE9Pq0Lb1Uylun5Y0moc=; h=DKIM-Signature:From:To:Subject:Date:Message-ID:MIME-Version; b=X8hFX0DHGabxFaMW9oLaeLUGsqS7dVUArcsbaqhDH1n4W7XYKpQ0ZnIh3EYUyhpG+FtaxsFe/XWG90Tdr2JUsB2oXC0mODE7L8vgn4+VnpfLEpcLAQuVZdX++F18P25UhI2+5Jesxr7uIIKj6Dblt8gDMGmgF6MoXG8dtFCvduU= ARC-Authentication-Results: i=1; server2.sourceware.org Received: by mail-ed1-x52a.google.com with SMTP id 4fb4d7f45d1cf-5b01af9b0c9so6166956a12.3 for ; Thu, 01 Aug 2024 07:59:15 -0700 (PDT) DKIM-Signature: v=1; a=rsa-sha256; c=relaxed/relaxed; d=embecosm.com; s=google; t=1722524353; x=1723129153; darn=gcc.gnu.org; h=content-transfer-encoding:mime-version:references:in-reply-to :message-id:date:subject:cc:to:from:from:to:cc:subject:date :message-id:reply-to; bh=q42QOzjyahHuu77/bEMZ7IolDhrUNmVahc7DETaZpY8=; b=VO2lRT81MkdQcY5I1Z4rID6xGs9js3BUaRf8L+oZaOuS0WkiJcpkTlGf5i7skiapBM GBfcHVEuCS6Y4nOSDJvyUCSDlCmkeF+pI0x9QR2rtl/vvxAGa1LXgO7UeBNj0MYH/Bl5 oot27PYn4jC1fqhHlv7+nOhYj0PzU5mspnGqOXh037xsmFxz40n9+aZ+quFdSnhdfbTU NqfOKBYglbxO6FOXwXbo6VidHrGVdyh3njXuPYbWdlzrL3ZUEUYshz51/kQDaiAnjdMt AkuATOnGBueyMFYMiQk0dgDl+MC4aACBXq9YnVwt2HsZcX4O7GHgUOm9f+qOZd7zJg3p jmfw== X-Google-DKIM-Signature: v=1; a=rsa-sha256; c=relaxed/relaxed; d=1e100.net; s=20230601; t=1722524353; x=1723129153; h=content-transfer-encoding:mime-version:references:in-reply-to :message-id:date:subject:cc:to:from:x-gm-message-state:from:to:cc :subject:date:message-id:reply-to; bh=q42QOzjyahHuu77/bEMZ7IolDhrUNmVahc7DETaZpY8=; b=Ktl/97IHGzUBC0TvoEmfDtRlQEuNAT4TGljKPUC+aXzudwW1bgT4AcY/LAzfddidfU 7gEDAkzDvAvdq7Ztf8tUd187PPBR8A0OHpw6MDKV5WkyMBkCo41R/VgdCUuBLPpv69v+ Ocf5q9Rr5uyUY16QVl6a2GC486Czo3SWug6PgGqrbbGrtVvm0Q5wHB3eoG5Jpzl+Za3M lhrZKNFOokyfM/Yj+9kMez4S7NsJ4/SJGPLAsv2csqX4srwPnDb4BWFFhR6IHwRGa7p8 XkDorHhHopqq5w2OHy35tJpXUfpcUUZlemIUS32iiYOJSabRhvHY4Isdj6I+MK5ROVL5 HCqg== X-Gm-Message-State: AOJu0YzOLv+E44+01BA/P5xnBAEsv0bZAKaOc7gQoHs1itNVIjVKZ2Sq k2MLv1hiGoo4Gyi0WfWXAJ5RHo0tsKuNU88ip4GRUoORdRM85FhEdMN23f51izOt2/M6CBPIb6O HRr15 X-Google-Smtp-Source: AGHT+IEZeKPNDoF2J4MrqIzpO+VG81XR2VbndPqbuym2qqYsKQGQZiTp568XabAXaVS9NCyRpc1xvg== X-Received: by 2002:aa7:c50e:0:b0:5a4:6dec:cd41 with SMTP id 4fb4d7f45d1cf-5b7f57f3a8dmr427752a12.28.1722524353281; Thu, 01 Aug 2024 07:59:13 -0700 (PDT) Received: from platypus.lan ([2a04:cec2:9:dc84:3622:6733:ff49:ee91]) by smtp.gmail.com with ESMTPSA id 4fb4d7f45d1cf-5ac63590592sm10252456a12.25.2024.08.01.07.59.12 (version=TLS1_3 cipher=TLS_AES_256_GCM_SHA384 bits=256/256); Thu, 01 Aug 2024 07:59:12 -0700 (PDT) From: Arthur Cohen To: gcc-patches@gcc.gnu.org Cc: gcc-rust@gcc.gnu.org, Jakub Dupak Subject: [PATCH 053/125] gccrs: TyTy: Variance analysis module Date: Thu, 1 Aug 2024 16:56:49 +0200 Message-ID: <20240801145809.366388-55-arthur.cohen@embecosm.com> X-Mailer: git-send-email 2.45.2 In-Reply-To: <20240801145809.366388-2-arthur.cohen@embecosm.com> References: <20240801145809.366388-2-arthur.cohen@embecosm.com> MIME-Version: 1.0 X-Spam-Status: No, score=-14.1 required=5.0 tests=BAYES_00, DKIM_SIGNED, DKIM_VALID, DKIM_VALID_AU, DKIM_VALID_EF, GIT_PATCH_0, RCVD_IN_DNSWL_NONE, SPF_HELO_NONE, SPF_PASS, TXREP autolearn=unavailable autolearn_force=no version=3.4.6 X-Spam-Checker-Version: SpamAssassin 3.4.6 (2021-04-09) on server2.sourceware.org X-BeenThere: gcc-patches@gcc.gnu.org X-Mailman-Version: 2.1.30 Precedence: list List-Id: Gcc-patches mailing list List-Unsubscribe: , List-Archive: List-Post: List-Help: List-Subscribe: , Errors-To: gcc-patches-bounces~incoming=patchwork.ozlabs.org@gcc.gnu.org From: Jakub Dupak gcc/rust/ChangeLog: * Make-lang.in: Add new .cc file. * rust-session-manager.cc (Session::compile_crate): Run analysis. * typecheck/rust-tyty-variance-analysis-private.h: New file. * typecheck/rust-tyty-variance-analysis.cc: New file. * typecheck/rust-tyty-variance-analysis.h: New file. * typecheck/rust-typecheck-context.cc (TypeCheckContext::get_variance_analysis_ctx): Variance analysis context. * typecheck/rust-hir-type-check.h (TypeCheckItem::visit): Variance analysis context. Signed-off-by: Jakub Dupak --- gcc/rust/Make-lang.in | 1 + gcc/rust/rust-session-manager.cc | 3 + gcc/rust/typecheck/rust-hir-type-check.h | 6 + gcc/rust/typecheck/rust-typecheck-context.cc | 6 + .../rust-tyty-variance-analysis-private.h | 304 ++++++++++ .../typecheck/rust-tyty-variance-analysis.cc | 541 ++++++++++++++++++ .../typecheck/rust-tyty-variance-analysis.h | 114 ++++ 7 files changed, 975 insertions(+) create mode 100644 gcc/rust/typecheck/rust-tyty-variance-analysis-private.h create mode 100644 gcc/rust/typecheck/rust-tyty-variance-analysis.cc create mode 100644 gcc/rust/typecheck/rust-tyty-variance-analysis.h diff --git a/gcc/rust/Make-lang.in b/gcc/rust/Make-lang.in index cbb9da0fe43..67df843349f 100644 --- a/gcc/rust/Make-lang.in +++ b/gcc/rust/Make-lang.in @@ -147,6 +147,7 @@ GRS_OBJS = \ rust/rust-tyty-util.o \ rust/rust-tyty-call.o \ rust/rust-tyty-subst.o \ + rust/rust-tyty-variance-analysis.o \ rust/rust-typecheck-context.o \ rust/rust-tyty-bounds.o \ rust/rust-hir-trait-resolve.o \ diff --git a/gcc/rust/rust-session-manager.cc b/gcc/rust/rust-session-manager.cc index 62c47b2e6de..ea99d019f64 100644 --- a/gcc/rust/rust-session-manager.cc +++ b/gcc/rust/rust-session-manager.cc @@ -48,6 +48,7 @@ #include "rust-attribute-values.h" #include "rust-borrow-checker.h" #include "rust-ast-validation.h" +#include "rust-tyty-variance-analysis.h" #include "input.h" #include "selftest.h" @@ -652,6 +653,8 @@ Session::compile_crate (const char *filename) // type resolve Resolver::TypeResolution::Resolve (hir); + Resolver::TypeCheckContext::get ()->get_variance_analysis_ctx ().solve (); + if (saw_errors ()) return; diff --git a/gcc/rust/typecheck/rust-hir-type-check.h b/gcc/rust/typecheck/rust-hir-type-check.h index 52c84fc4435..c32fa4e8487 100644 --- a/gcc/rust/typecheck/rust-hir-type-check.h +++ b/gcc/rust/typecheck/rust-hir-type-check.h @@ -24,6 +24,7 @@ #include "rust-hir-trait-reference.h" #include "rust-autoderef.h" #include "rust-tyty-region.h" +#include "rust-tyty-variance-analysis.h" #include @@ -233,6 +234,8 @@ public: void compute_inference_variables (bool error); + TyTy::VarianceAnalysis::CrateCtx &get_variance_analysis_ctx (); + private: TypeCheckContext (); @@ -272,6 +275,9 @@ private: std::set querys_in_progress; std::set trait_queries_in_progress; + // variance analysis + TyTy::VarianceAnalysis::CrateCtx variance_analysis_ctx; + /** Used to resolve (interned) lifetime names to their bounding scope. */ class LifetimeResolver { diff --git a/gcc/rust/typecheck/rust-typecheck-context.cc b/gcc/rust/typecheck/rust-typecheck-context.cc index 9059e0261b3..ab0093a273b 100644 --- a/gcc/rust/typecheck/rust-typecheck-context.cc +++ b/gcc/rust/typecheck/rust-typecheck-context.cc @@ -617,6 +617,12 @@ TypeCheckContext::compute_inference_variables (bool error) }); } +TyTy::VarianceAnalysis::CrateCtx & +TypeCheckContext::get_variance_analysis_ctx () +{ + return variance_analysis_ctx; +} + // TypeCheckContextItem TypeCheckContextItem::Item::Item (HIR::Function *item) : item (item) {} diff --git a/gcc/rust/typecheck/rust-tyty-variance-analysis-private.h b/gcc/rust/typecheck/rust-tyty-variance-analysis-private.h new file mode 100644 index 00000000000..ab8c039238e --- /dev/null +++ b/gcc/rust/typecheck/rust-tyty-variance-analysis-private.h @@ -0,0 +1,304 @@ +#ifndef RUST_TYTY_VARIANCE_ANALYSIS_PRIVATE_H +#define RUST_TYTY_VARIANCE_ANALYSIS_PRIVATE_H + +#include "rust-tyty-variance-analysis.h" + +#include "rust-tyty-visitor.h" + +namespace Rust { +namespace TyTy { +namespace VarianceAnalysis { + +using SolutionIndex = uint32_t; + +/** Term descibing variance relations. */ +struct Term +{ + enum Kind : uint8_t + { + CONST, + REF, + TRANSFORM, + }; + + Kind kind; + union + { + struct + { + Term *lhs; + Term *rhs; + } transform; + SolutionIndex ref; + Variance const_val; + }; + + Term () {} + + Term (Variance variance) : kind (CONST), const_val (variance) {} + + WARN_UNUSED_RESULT bool is_const () const { return kind == CONST; } + + static Term make_ref (SolutionIndex index); + + static Term make_transform (Term lhs, Term rhs); +}; + +/** Variance constraint of a type parameter. */ +struct Constraint +{ + SolutionIndex target_index; + Term *term; +}; + +/** Abstract variance visitor context. */ +template class VarianceVisitorCtx +{ +public: + virtual ~VarianceVisitorCtx () = default; + + virtual void add_constraints_from_ty (BaseType *ty, VARIANCE variance) = 0; + virtual void add_constraints_from_region (const Region ®ion, + VARIANCE variance) + = 0; + void add_constraints_from_mutability (BaseType *type, Mutability mutability, + VARIANCE variance) + { + switch (mutability) + { + case Mutability::Imm: + return add_constraints_from_ty (type, variance); + case Mutability::Mut: + return add_constraints_from_ty (type, Variance::invariant ()); + } + } + virtual void + add_constraints_from_generic_args (HirId ref, SubstitutionRef &subst, + VARIANCE variance, bool invariant_args) + = 0; + virtual void add_constrints_from_param (ParamType ¶m, VARIANCE variance) + = 0; + virtual VARIANCE contra (VARIANCE variance) = 0; +}; + +template class VisitorBase final : public TyVisitor +{ + VarianceVisitorCtx &ctx; + VARIANCE variance; + +public: + VisitorBase (VarianceVisitorCtx &ctx, VARIANCE variance) + : ctx (ctx), variance (variance) + {} + + void visit (BoolType &type) override {} + void visit (CharType &type) override {} + void visit (IntType &type) override {} + void visit (UintType &type) override {} + void visit (FloatType &type) override {} + void visit (USizeType &type) override {} + void visit (ISizeType &type) override {} + void visit (StrType &type) override {} + void visit (NeverType &type) override {} + + void visit (ClosureType &type) override {} + void visit (FnType &type) override + { + for (auto ®ion : type.get_used_arguments ().get_regions ()) + ctx.add_constraints_from_region (region, Variance::invariant ()); + } + + void visit (ReferenceType &type) override + { + ctx.add_constraints_from_region (type.get_region (), variance); + ctx.add_constraints_from_mutability (type.get_base (), type.mutability (), + variance); + } + void visit (ArrayType &type) override + { + ctx.add_constraints_from_ty (type.get_element_type (), variance); + } + void visit (SliceType &type) override + { + ctx.add_constraints_from_ty (type.get_element_type (), variance); + } + void visit (PointerType &type) override + { + ctx.add_constraints_from_ty (type.get_base (), variance); + ctx.add_constraints_from_mutability (type.get_base (), type.mutability (), + variance); + } + void visit (TupleType &type) override + { + for (auto &elem : type.get_fields ()) + ctx.add_constraints_from_ty (elem.get_tyty (), variance); + } + void visit (ADTType &type) override + { + ctx.add_constraints_from_generic_args (type.get_orig_ref (), type, variance, + false); + } + void visit (ProjectionType &type) override + { + ctx.add_constraints_from_generic_args (type.get_orig_ref (), type, variance, + true); + } + void visit (ParamType &type) override + { + ctx.add_constrints_from_param (type, variance); + } + void visit (FnPtr &type) override + { + auto contra = ctx.contra (variance); + + for (auto ¶m : type.get_params ()) + { + ctx.add_constraints_from_ty (param.get_tyty (), contra); + } + + ctx.add_constraints_from_ty (type.get_return_type (), variance); + } + + void visit (ErrorType &type) override {} + + void visit (PlaceholderType &type) override { rust_unreachable (); } + void visit (InferType &type) override { rust_unreachable (); } + + void visit (DynamicObjectType &type) override + { + // TODO + } +}; + +/** Per crate context for generic type variance analysis. */ +class GenericTyPerCrateCtx +{ +public: // External API + /** Add a type to context and process its variance constraints. */ + void process_type (ADTType &ty); + + /** + * Solve for all variance constraints and clear temporary data. + * + * Only keeps the results. + */ + void solve (); + + /** Prints solution debug output. To be called after solve. */ + void debug_print_solutions (); + + tl::optional lookup_type_index (HirId orig_ref); + +public: // Module internal API + /** Format term tree to string. */ + WARN_UNUSED_RESULT std::string to_string (const Term &term) const; + + /** Formats as `[```]` */ + WARN_UNUSED_RESULT std::string to_string (SolutionIndex index) const; + + /** Evaluate a variance relation expression (term tree). */ + Variance evaluate (Term *term); + + std::vector query_generic_variance (const ADTType &type); + +public: // Data used by visitors. + // This whole class is private, therfore members can be public. + + /** Current solutions. Initiated to bivariant. */ + std::vector solutions; + + /** Constrains on solutions. Iteratively applied until fixpoint. */ + std::vector constraints; + + /** Maps TyTy::orig_ref to an index of first solution for this type. */ + std::unordered_map map_from_ty_orig_ref; +}; + +/** Visitor context for generic type variance analysis used for processing of a + * single type. */ +class GenericTyVisitorCtx : VarianceVisitorCtx +{ + using Visitor = VisitorBase; + +public: + explicit GenericTyVisitorCtx (GenericTyPerCrateCtx &ctx) : ctx (ctx) {} + /** Entry point: Add a type to context and process its variance constraints. + */ + void process_type (ADTType &ty); + +private: + /** Resolve a type from a TyTy::ref. */ + SolutionIndex lookup_or_add_type (HirId hir_id); + + /** Visit an inner type and add its constraints. */ + void add_constraints_from_ty (BaseType *ty, Term variance) override; + + void add_constraint (SolutionIndex index, Term term); + + void add_constraints_from_region (const Region ®ion, Term term) override; + + void add_constraints_from_generic_args (HirId ref, SubstitutionRef &subst, + Term variance, + bool invariant_args) override; + + void add_constrints_from_param (ParamType &type, Term variance) override; + + /** Construct a term for type in contravaraint position. */ + Term contra (Term variance) override; + +private: + GenericTyPerCrateCtx &ctx; + +private: // Per type processing context + /** Index of the solution first **lifetime param** for the current type. */ + SolutionIndex first_lifetime = 0; + + /** Index of the solution first **type param** for the current type. */ + SolutionIndex first_type = 0; + + /** Maps type param names to index among type params. */ + std::vector param_names; +}; + +/** Visitor context for basic type variance analysis. */ +class TyVisitorCtx : public VarianceVisitorCtx +{ +public: + using Visitor = VisitorBase; + + TyVisitorCtx (GenericTyPerCrateCtx &ctx) : ctx (ctx) {} + + std::vector collect_variances (BaseType &ty) + { + add_constraints_from_ty (&ty, Variance::covariant ()); + return variances; + } + + std::vector collect_regions (BaseType &ty) + { + add_constraints_from_ty (&ty, Variance::covariant ()); + return regions; + } + + void add_constraints_from_ty (BaseType *ty, Variance variance) override; + void add_constraints_from_region (const Region ®ion, + Variance variance) override; + void add_constraints_from_generic_args (HirId ref, SubstitutionRef &subst, + Variance variance, + bool invariant_args) override; + void add_constrints_from_param (ParamType ¶m, Variance variance) override + {} + Variance contra (Variance variance) override; + +private: + GenericTyPerCrateCtx &ctx; + std::vector variances; + std::vector regions; +}; + +} // namespace VarianceAnalysis + +} // namespace TyTy +} // namespace Rust + +#endif // RUST_TYTY_VARIANCE_ANALYSIS_PRIVATE_H diff --git a/gcc/rust/typecheck/rust-tyty-variance-analysis.cc b/gcc/rust/typecheck/rust-tyty-variance-analysis.cc new file mode 100644 index 00000000000..5a21d69651b --- /dev/null +++ b/gcc/rust/typecheck/rust-tyty-variance-analysis.cc @@ -0,0 +1,541 @@ +#include "rust-tyty-variance-analysis-private.h" +#include "rust-hir-type-check.h" + +namespace Rust { +namespace TyTy { + +BaseType * +lookup_type (HirId ref) +{ + BaseType *ty = nullptr; + bool ok = Resolver::TypeCheckContext::get ()->lookup_type (ref, &ty); + rust_assert (ok); + return ty; +} + +namespace VarianceAnalysis { + +CrateCtx::CrateCtx () : private_ctx (new GenericTyPerCrateCtx ()) {} + +// Must be here because of incomplete type. +CrateCtx::~CrateCtx () = default; + +void +CrateCtx::add_type_constraints (ADTType &type) +{ + private_ctx->process_type (type); +} + +void +CrateCtx::solve () +{ + private_ctx->solve (); + private_ctx->debug_print_solutions (); +} + +std::vector +CrateCtx::query_generic_variance (const ADTType &type) +{ + return private_ctx->query_generic_variance (type); +} + +std::vector +CrateCtx::query_type_variances (BaseType *type) +{ + TyVisitorCtx ctx (*private_ctx); + return ctx.collect_variances (*type); +} + +std::vector +CrateCtx::query_type_regions (BaseType *type) +{ + TyVisitorCtx ctx (*private_ctx); + return ctx.collect_regions (*type); +} + +Variance +Variance::reverse () const +{ + switch (kind) + { + case BIVARIANT: + return bivariant (); + case COVARIANT: + return contravariant (); + case CONTRAVARIANT: + return covariant (); + case INVARIANT: + return invariant (); + } + + rust_unreachable (); +} + +Variance +Variance::join (Variance lhs, Variance rhs) +{ + return {Kind (lhs.kind | rhs.kind)}; +} + +void +Variance::join (Variance rhs) +{ + *this = join (*this, rhs); +} + +Variance +Variance::transform (Variance lhs, Variance rhs) +{ + switch (lhs.kind) + { + case BIVARIANT: + return bivariant (); + case COVARIANT: + return rhs; + case CONTRAVARIANT: + return rhs.reverse (); + case INVARIANT: + return invariant (); + } + rust_unreachable (); +} + +std::string +Variance::as_string () const +{ + switch (kind) + { + case BIVARIANT: + return "o"; + case COVARIANT: + return "+"; + case CONTRAVARIANT: + return "-"; + case INVARIANT: + return "*"; + } + rust_unreachable (); +} + +void +GenericTyPerCrateCtx::process_type (ADTType &type) +{ + GenericTyVisitorCtx (*this).process_type (type); +} + +void +GenericTyPerCrateCtx::solve () +{ + rust_debug ("Variance analysis solving started:"); + + // Fix point iteration + bool changed = true; + while (changed) + { + changed = false; + for (auto constraint : constraints) + { + rust_debug ("\tapplying constraint: %s <= %s", + to_string (constraint.target_index).c_str (), + to_string (*constraint.term).c_str ()); + + auto old_solution = solutions[constraint.target_index]; + auto new_solution + = Variance::join (old_solution, evaluate (constraint.term)); + + if (old_solution != new_solution) + { + rust_debug ("\t\tsolution changed: %s => %s", + old_solution.as_string ().c_str (), + new_solution.as_string ().c_str ()); + + changed = true; + solutions[constraint.target_index] = new_solution; + } + } + } + + constraints.clear (); + constraints.shrink_to_fit (); +} + +void +GenericTyPerCrateCtx::debug_print_solutions () +{ + rust_debug ("Variance analysis results:"); + + for (auto type : map_from_ty_orig_ref) + { + auto solution_index = type.second; + auto ref = type.first; + + BaseType *ty = lookup_type (ref); + + std::string result = "\t"; + + if (auto adt = ty->try_as ()) + { + result += adt->get_identifier (); + result += "<"; + + size_t i = solution_index; + for (auto ®ion : adt->get_used_arguments ().get_regions ()) + { + (void) region; + if (i > solution_index) + result += ", "; + result += solutions[i].as_string (); + i++; + } + for (auto ¶m : adt->get_substs ()) + { + if (i > solution_index) + result += ", "; + result += param.get_generic_param () + .get_type_representation () + .as_string (); + result += "="; + result += solutions[i].as_string (); + i++; + } + + result += ">"; + } + else + { + rust_sorry_at ( + ty->get_ref (), + "This is a compiler bug: Unhandled type in variance analysis"); + } + rust_debug ("%s", result.c_str ()); + } +} + +tl::optional +GenericTyPerCrateCtx::lookup_type_index (HirId orig_ref) +{ + auto it = map_from_ty_orig_ref.find (orig_ref); + if (it != map_from_ty_orig_ref.end ()) + { + return it->second; + } + return tl::nullopt; +} + +void +GenericTyVisitorCtx::process_type (ADTType &ty) +{ + rust_debug ("add_type_constraints: %s", ty.as_string ().c_str ()); + + first_lifetime = lookup_or_add_type (ty.get_orig_ref ()); + first_type = first_lifetime + ty.get_used_arguments ().get_regions ().size (); + + for (const auto ¶m : ty.get_substs ()) + param_names.push_back ( + param.get_generic_param ().get_type_representation ().as_string ()); + + for (const auto &variant : ty.get_variants ()) + { + if (variant->get_variant_type () != VariantDef::NUM) + { + for (const auto &field : variant->get_fields ()) + add_constraints_from_ty (field->get_field_type (), + Variance::covariant ()); + } + } +} + +std::string +GenericTyPerCrateCtx::to_string (const Term &term) const +{ + switch (term.kind) + { + case Term::CONST: + return term.const_val.as_string (); + case Term::REF: + return "v(" + to_string (term.ref) + ")"; + case Term::TRANSFORM: + return "(" + to_string (*term.transform.lhs) + " x " + + to_string (*term.transform.rhs) + ")"; + } + rust_unreachable (); +} + +std::string +GenericTyPerCrateCtx::to_string (SolutionIndex index) const +{ + // Search all values in def_id_to_solution_index_start and find key for + // largest value smaller than index + std::pair best = {0, 0}; + + for (const auto &ty_map : map_from_ty_orig_ref) + { + if (ty_map.second <= index && ty_map.first > best.first) + best = ty_map; + } + rust_assert (best.first != 0); + + BaseType *ty = lookup_type (best.first); + + std::string result = ""; + if (auto adt = ty->try_as ()) + { + result += (adt->get_identifier ()); + } + else + { + result += ty->as_string (); + } + + result += "[" + std::to_string (index - best.first) + "]"; + return result; +} + +Variance +GenericTyPerCrateCtx::evaluate (Term *term) +{ + switch (term->kind) + { + case Term::CONST: + return term->const_val; + case Term::REF: + return solutions[term->ref]; + case Term::TRANSFORM: + return Variance::transform (evaluate (term->transform.lhs), + evaluate (term->transform.rhs)); + } + rust_unreachable (); +} + +std::vector +GenericTyPerCrateCtx::query_generic_variance (const ADTType &type) +{ + auto solution_index = lookup_type_index (type.get_orig_ref ()); + rust_assert (solution_index.has_value ()); + auto num_lifetimes = type.get_num_lifetime_params (); + auto num_types = type.get_num_type_params (); + + std::vector result; + for (size_t i = 0; i < num_lifetimes + num_types; ++i) + { + result.push_back (solutions[solution_index.value () + i]); + } + + return result; +} + +SolutionIndex +GenericTyVisitorCtx::lookup_or_add_type (HirId hir_id) +{ + BaseType *ty = lookup_type (hir_id); + auto index = ctx.lookup_type_index (hir_id); + if (index.has_value ()) + { + return index.value (); + } + + SubstitutionRef *subst = nullptr; + if (auto adt = ty->try_as ()) + { + subst = adt; + } + else + { + rust_sorry_at ( + ty->get_locus (), + "This is a compiler bug: Unhandled type in variance analysis"); + } + rust_assert (subst != nullptr); + + auto solution_index = ctx.solutions.size (); + ctx.map_from_ty_orig_ref.emplace (ty->get_orig_ref (), solution_index); + + auto num_lifetime_param = subst->get_used_arguments ().get_regions ().size (); + auto num_type_param = subst->get_num_substitutions (); + + for (size_t i = 0; i < num_lifetime_param + num_type_param; ++i) + ctx.solutions.emplace_back (Variance::bivariant ()); + + return solution_index; +} + +void +GenericTyVisitorCtx::add_constraints_from_ty (BaseType *type, Term variance) +{ + rust_debug ("\tadd_constraint_from_ty: %s with v=%s", + type->as_string ().c_str (), ctx.to_string (variance).c_str ()); + + Visitor visitor (*this, variance); + type->accept_vis (visitor); +} + +void +GenericTyVisitorCtx::add_constraint (SolutionIndex index, Term term) +{ + rust_debug ("\t\tadd_constraint: %s", ctx.to_string (term).c_str ()); + + if (term.kind == Term::CONST) + { + // Constant terms do not depend on other solutions, so we can + // immediately apply them. + ctx.solutions[index].join (term.const_val); + } + else + { + ctx.constraints.push_back ({index, new Term (term)}); + } +} + +void +GenericTyVisitorCtx::add_constraints_from_region (const Region ®ion, + Term term) +{ + if (region.is_early_bound ()) + { + add_constraint (first_lifetime + region.get_index (), term); + } +} + +void +GenericTyVisitorCtx::add_constraints_from_generic_args (HirId ref, + SubstitutionRef &subst, + Term variance, + bool invariant_args) +{ + SolutionIndex solution_index = lookup_or_add_type (ref); + + size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size (); + size_t num_types = subst.get_substs ().size (); + + for (size_t i = 0; i < num_lifetimes + num_types; ++i) + { + // TODO: What about variance from other crates? + auto variance_i + = invariant_args + ? Term::make_transform (variance, Variance::invariant ()) + : Term::make_transform (variance, + Term::make_ref (solution_index + i)); + + if (i < num_lifetimes) + { + auto region_i = i; + auto ®ion + = subst.get_substitution_arguments ().get_mut_regions ()[region_i]; + add_constraints_from_region (region, variance_i); + } + else + { + auto type_i = i - num_lifetimes; + auto arg = subst.get_arg_at (type_i); + if (arg.has_value ()) + { + add_constraints_from_ty (arg.value ().get_tyty (), variance_i); + } + } + } +} +void +GenericTyVisitorCtx::add_constrints_from_param (ParamType &type, Term variance) +{ + auto it + = std::find (param_names.begin (), param_names.end (), type.get_name ()); + rust_assert (it != param_names.end ()); + + auto index = first_type + std::distance (param_names.begin (), it); + + add_constraint (index, variance); +} + +Term +GenericTyVisitorCtx::contra (Term variance) +{ + return Term::make_transform (variance, Variance::contravariant ()); +} + +void +TyVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance) +{ + Visitor visitor (*this, variance); + ty->accept_vis (visitor); +} + +void +TyVisitorCtx::add_constraints_from_region (const Region ®ion, + Variance variance) +{ + variances.push_back (variance); + regions.push_back (region); +} + +void +TyVisitorCtx::add_constraints_from_generic_args (HirId ref, + SubstitutionRef &subst, + Variance variance, + bool invariant_args) +{ + // Handle function + auto variances + = ctx.query_generic_variance (*lookup_type (ref)->as ()); + + size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size (); + size_t num_types = subst.get_substs ().size (); + + for (size_t i = 0; i < num_lifetimes + num_types; ++i) + { + // TODO: What about variance from other crates? + auto variance_i + = invariant_args + ? Variance::transform (variance, Variance::invariant ()) + : Variance::transform (variance, variances[i]); + + if (i < num_lifetimes) + { + auto region_i = i; + auto ®ion = subst.get_used_arguments ().get_regions ()[region_i]; + add_constraints_from_region (region, variance_i); + } + else + { + auto type_i = i - num_lifetimes; + auto arg = subst.get_arg_at (type_i); + if (arg.has_value ()) + { + add_constraints_from_ty (arg.value ().get_tyty (), variance_i); + } + } + } +} + +Variance +TyVisitorCtx::contra (Variance variance) +{ + return Variance::transform (variance, Variance::contravariant ()); +} + +Term +Term::make_ref (SolutionIndex index) +{ + Term term; + term.kind = REF; + term.ref = index; + return term; +} + +Term +Term::make_transform (Term lhs, Term rhs) +{ + if (lhs.is_const () && rhs.is_const ()) + { + return Variance::transform (lhs.const_val, rhs.const_val); + } + + Term term; + term.kind = TRANSFORM; + term.transform.lhs = new Term (lhs); + term.transform.rhs = new Term (rhs); + return term; +} + +} // namespace VarianceAnalysis +} // namespace TyTy +} // namespace Rust \ No newline at end of file diff --git a/gcc/rust/typecheck/rust-tyty-variance-analysis.h b/gcc/rust/typecheck/rust-tyty-variance-analysis.h new file mode 100644 index 00000000000..a1defd62e54 --- /dev/null +++ b/gcc/rust/typecheck/rust-tyty-variance-analysis.h @@ -0,0 +1,114 @@ +#ifndef RUST_TYTY_VARIANCE_ANALYSIS_H +#define RUST_TYTY_VARIANCE_ANALYSIS_H + +#include "rust-tyty.h" + +namespace Rust { +namespace TyTy { +namespace VarianceAnalysis { + +class Variance; +class GenericTyPerCrateCtx; + +/** Per crate context for variance analysis. */ +class CrateCtx +{ +public: + CrateCtx (); + ~CrateCtx (); + + /** Add type to variance analysis context. */ + void add_type_constraints (ADTType &type); + + /** Solve all constraints and print debug output. */ + void solve (); + + /** Get variance of a type parameters. */ + std::vector query_generic_variance (const ADTType &type); + + /** Get variance of a type body (members, fn parameters...). */ + std::vector query_type_variances (BaseType *type); + + /** Get regions mentioned in a type. */ + std::vector query_type_regions (BaseType *type); + +private: + std::unique_ptr private_ctx; +}; + +/** Variance semilattice */ +class Variance +{ + enum Kind : uint8_t + { + BIVARIANT = 0, // 0b00 + COVARIANT = 1, // 0b01 + CONTRAVARIANT = 2, // 0b10 + INVARIANT = 3, // 0b11 + } kind; + + static constexpr auto TOP = BIVARIANT; + static constexpr auto BOTTOM = INVARIANT; + + constexpr Variance (Kind kind) : kind (kind) {} + +public: + constexpr Variance () : kind (TOP) {} + + WARN_UNUSED_RESULT constexpr bool is_bivariant () const + { + return kind == BIVARIANT; + } + WARN_UNUSED_RESULT constexpr bool is_covariant () const + { + return kind == COVARIANT; + } + WARN_UNUSED_RESULT constexpr bool is_contravariant () const + { + return kind == CONTRAVARIANT; + } + WARN_UNUSED_RESULT constexpr bool is_invariant () const + { + return kind == INVARIANT; + } + + static constexpr Variance bivariant () { return {BIVARIANT}; } + static constexpr Variance covariant () { return {COVARIANT}; } + static constexpr Variance contravariant () { return {CONTRAVARIANT}; } + static constexpr Variance invariant () { return {INVARIANT}; } + + WARN_UNUSED_RESULT Variance reverse () const; + static WARN_UNUSED_RESULT Variance join (Variance lhs, Variance rhs); + + void join (Variance rhs); + + /** + * Variance composition function. + * + * For `A` and `B` and the composition `A>` the variance of + * `v(A>, X)` is defined as: + * ``` + * v(A>, X) = v(A, X).transform(v(B, X)) + * ``` + */ + static WARN_UNUSED_RESULT Variance transform (Variance lhs, Variance rhs); + + constexpr friend bool operator== (const Variance &lhs, const Variance &rhs) + { + return lhs.kind == rhs.kind; + } + + constexpr friend bool operator!= (const Variance &lhs, const Variance &rhs) + { + return !(lhs == rhs); + } + + WARN_UNUSED_RESULT std::string as_string () const; +}; + +} // namespace VarianceAnalysis + +} // namespace TyTy +} // namespace Rust + +#endif // RUST_TYTY_VARIANCE_ANALYSIS_H