@@ -147,6 +147,7 @@ GRS_OBJS = \
rust/rust-tyty-util.o \
rust/rust-tyty-call.o \
rust/rust-tyty-subst.o \
+ rust/rust-tyty-variance-analysis.o \
rust/rust-typecheck-context.o \
rust/rust-tyty-bounds.o \
rust/rust-hir-trait-resolve.o \
@@ -48,6 +48,7 @@
#include "rust-attribute-values.h"
#include "rust-borrow-checker.h"
#include "rust-ast-validation.h"
+#include "rust-tyty-variance-analysis.h"
#include "input.h"
#include "selftest.h"
@@ -652,6 +653,8 @@ Session::compile_crate (const char *filename)
// type resolve
Resolver::TypeResolution::Resolve (hir);
+ Resolver::TypeCheckContext::get ()->get_variance_analysis_ctx ().solve ();
+
if (saw_errors ())
return;
@@ -24,6 +24,7 @@
#include "rust-hir-trait-reference.h"
#include "rust-autoderef.h"
#include "rust-tyty-region.h"
+#include "rust-tyty-variance-analysis.h"
#include <stack>
@@ -233,6 +234,8 @@ public:
void compute_inference_variables (bool error);
+ TyTy::VarianceAnalysis::CrateCtx &get_variance_analysis_ctx ();
+
private:
TypeCheckContext ();
@@ -272,6 +275,9 @@ private:
std::set<HirId> querys_in_progress;
std::set<DefId> trait_queries_in_progress;
+ // variance analysis
+ TyTy::VarianceAnalysis::CrateCtx variance_analysis_ctx;
+
/** Used to resolve (interned) lifetime names to their bounding scope. */
class LifetimeResolver
{
@@ -617,6 +617,12 @@ TypeCheckContext::compute_inference_variables (bool error)
});
}
+TyTy::VarianceAnalysis::CrateCtx &
+TypeCheckContext::get_variance_analysis_ctx ()
+{
+ return variance_analysis_ctx;
+}
+
// TypeCheckContextItem
TypeCheckContextItem::Item::Item (HIR::Function *item) : item (item) {}
new file mode 100644
@@ -0,0 +1,304 @@
+#ifndef RUST_TYTY_VARIANCE_ANALYSIS_PRIVATE_H
+#define RUST_TYTY_VARIANCE_ANALYSIS_PRIVATE_H
+
+#include "rust-tyty-variance-analysis.h"
+
+#include "rust-tyty-visitor.h"
+
+namespace Rust {
+namespace TyTy {
+namespace VarianceAnalysis {
+
+using SolutionIndex = uint32_t;
+
+/** Term descibing variance relations. */
+struct Term
+{
+ enum Kind : uint8_t
+ {
+ CONST,
+ REF,
+ TRANSFORM,
+ };
+
+ Kind kind;
+ union
+ {
+ struct
+ {
+ Term *lhs;
+ Term *rhs;
+ } transform;
+ SolutionIndex ref;
+ Variance const_val;
+ };
+
+ Term () {}
+
+ Term (Variance variance) : kind (CONST), const_val (variance) {}
+
+ WARN_UNUSED_RESULT bool is_const () const { return kind == CONST; }
+
+ static Term make_ref (SolutionIndex index);
+
+ static Term make_transform (Term lhs, Term rhs);
+};
+
+/** Variance constraint of a type parameter. */
+struct Constraint
+{
+ SolutionIndex target_index;
+ Term *term;
+};
+
+/** Abstract variance visitor context. */
+template <typename VARIANCE> class VarianceVisitorCtx
+{
+public:
+ virtual ~VarianceVisitorCtx () = default;
+
+ virtual void add_constraints_from_ty (BaseType *ty, VARIANCE variance) = 0;
+ virtual void add_constraints_from_region (const Region ®ion,
+ VARIANCE variance)
+ = 0;
+ void add_constraints_from_mutability (BaseType *type, Mutability mutability,
+ VARIANCE variance)
+ {
+ switch (mutability)
+ {
+ case Mutability::Imm:
+ return add_constraints_from_ty (type, variance);
+ case Mutability::Mut:
+ return add_constraints_from_ty (type, Variance::invariant ());
+ }
+ }
+ virtual void
+ add_constraints_from_generic_args (HirId ref, SubstitutionRef &subst,
+ VARIANCE variance, bool invariant_args)
+ = 0;
+ virtual void add_constrints_from_param (ParamType ¶m, VARIANCE variance)
+ = 0;
+ virtual VARIANCE contra (VARIANCE variance) = 0;
+};
+
+template <typename VARIANCE> class VisitorBase final : public TyVisitor
+{
+ VarianceVisitorCtx<VARIANCE> &ctx;
+ VARIANCE variance;
+
+public:
+ VisitorBase (VarianceVisitorCtx<VARIANCE> &ctx, VARIANCE variance)
+ : ctx (ctx), variance (variance)
+ {}
+
+ void visit (BoolType &type) override {}
+ void visit (CharType &type) override {}
+ void visit (IntType &type) override {}
+ void visit (UintType &type) override {}
+ void visit (FloatType &type) override {}
+ void visit (USizeType &type) override {}
+ void visit (ISizeType &type) override {}
+ void visit (StrType &type) override {}
+ void visit (NeverType &type) override {}
+
+ void visit (ClosureType &type) override {}
+ void visit (FnType &type) override
+ {
+ for (auto ®ion : type.get_used_arguments ().get_regions ())
+ ctx.add_constraints_from_region (region, Variance::invariant ());
+ }
+
+ void visit (ReferenceType &type) override
+ {
+ ctx.add_constraints_from_region (type.get_region (), variance);
+ ctx.add_constraints_from_mutability (type.get_base (), type.mutability (),
+ variance);
+ }
+ void visit (ArrayType &type) override
+ {
+ ctx.add_constraints_from_ty (type.get_element_type (), variance);
+ }
+ void visit (SliceType &type) override
+ {
+ ctx.add_constraints_from_ty (type.get_element_type (), variance);
+ }
+ void visit (PointerType &type) override
+ {
+ ctx.add_constraints_from_ty (type.get_base (), variance);
+ ctx.add_constraints_from_mutability (type.get_base (), type.mutability (),
+ variance);
+ }
+ void visit (TupleType &type) override
+ {
+ for (auto &elem : type.get_fields ())
+ ctx.add_constraints_from_ty (elem.get_tyty (), variance);
+ }
+ void visit (ADTType &type) override
+ {
+ ctx.add_constraints_from_generic_args (type.get_orig_ref (), type, variance,
+ false);
+ }
+ void visit (ProjectionType &type) override
+ {
+ ctx.add_constraints_from_generic_args (type.get_orig_ref (), type, variance,
+ true);
+ }
+ void visit (ParamType &type) override
+ {
+ ctx.add_constrints_from_param (type, variance);
+ }
+ void visit (FnPtr &type) override
+ {
+ auto contra = ctx.contra (variance);
+
+ for (auto ¶m : type.get_params ())
+ {
+ ctx.add_constraints_from_ty (param.get_tyty (), contra);
+ }
+
+ ctx.add_constraints_from_ty (type.get_return_type (), variance);
+ }
+
+ void visit (ErrorType &type) override {}
+
+ void visit (PlaceholderType &type) override { rust_unreachable (); }
+ void visit (InferType &type) override { rust_unreachable (); }
+
+ void visit (DynamicObjectType &type) override
+ {
+ // TODO
+ }
+};
+
+/** Per crate context for generic type variance analysis. */
+class GenericTyPerCrateCtx
+{
+public: // External API
+ /** Add a type to context and process its variance constraints. */
+ void process_type (ADTType &ty);
+
+ /**
+ * Solve for all variance constraints and clear temporary data.
+ *
+ * Only keeps the results.
+ */
+ void solve ();
+
+ /** Prints solution debug output. To be called after solve. */
+ void debug_print_solutions ();
+
+ tl::optional<SolutionIndex> lookup_type_index (HirId orig_ref);
+
+public: // Module internal API
+ /** Format term tree to string. */
+ WARN_UNUSED_RESULT std::string to_string (const Term &term) const;
+
+ /** Formats as <type ident>`[`<param index>``]` */
+ WARN_UNUSED_RESULT std::string to_string (SolutionIndex index) const;
+
+ /** Evaluate a variance relation expression (term tree). */
+ Variance evaluate (Term *term);
+
+ std::vector<Variance> query_generic_variance (const ADTType &type);
+
+public: // Data used by visitors.
+ // This whole class is private, therfore members can be public.
+
+ /** Current solutions. Initiated to bivariant. */
+ std::vector<Variance> solutions;
+
+ /** Constrains on solutions. Iteratively applied until fixpoint. */
+ std::vector<Constraint> constraints;
+
+ /** Maps TyTy::orig_ref to an index of first solution for this type. */
+ std::unordered_map<HirId, SolutionIndex> map_from_ty_orig_ref;
+};
+
+/** Visitor context for generic type variance analysis used for processing of a
+ * single type. */
+class GenericTyVisitorCtx : VarianceVisitorCtx<Term>
+{
+ using Visitor = VisitorBase<Term>;
+
+public:
+ explicit GenericTyVisitorCtx (GenericTyPerCrateCtx &ctx) : ctx (ctx) {}
+ /** Entry point: Add a type to context and process its variance constraints.
+ */
+ void process_type (ADTType &ty);
+
+private:
+ /** Resolve a type from a TyTy::ref. */
+ SolutionIndex lookup_or_add_type (HirId hir_id);
+
+ /** Visit an inner type and add its constraints. */
+ void add_constraints_from_ty (BaseType *ty, Term variance) override;
+
+ void add_constraint (SolutionIndex index, Term term);
+
+ void add_constraints_from_region (const Region ®ion, Term term) override;
+
+ void add_constraints_from_generic_args (HirId ref, SubstitutionRef &subst,
+ Term variance,
+ bool invariant_args) override;
+
+ void add_constrints_from_param (ParamType &type, Term variance) override;
+
+ /** Construct a term for type in contravaraint position. */
+ Term contra (Term variance) override;
+
+private:
+ GenericTyPerCrateCtx &ctx;
+
+private: // Per type processing context
+ /** Index of the solution first **lifetime param** for the current type. */
+ SolutionIndex first_lifetime = 0;
+
+ /** Index of the solution first **type param** for the current type. */
+ SolutionIndex first_type = 0;
+
+ /** Maps type param names to index among type params. */
+ std::vector<std::string> param_names;
+};
+
+/** Visitor context for basic type variance analysis. */
+class TyVisitorCtx : public VarianceVisitorCtx<Variance>
+{
+public:
+ using Visitor = VisitorBase<Variance>;
+
+ TyVisitorCtx (GenericTyPerCrateCtx &ctx) : ctx (ctx) {}
+
+ std::vector<Variance> collect_variances (BaseType &ty)
+ {
+ add_constraints_from_ty (&ty, Variance::covariant ());
+ return variances;
+ }
+
+ std::vector<Region> collect_regions (BaseType &ty)
+ {
+ add_constraints_from_ty (&ty, Variance::covariant ());
+ return regions;
+ }
+
+ void add_constraints_from_ty (BaseType *ty, Variance variance) override;
+ void add_constraints_from_region (const Region ®ion,
+ Variance variance) override;
+ void add_constraints_from_generic_args (HirId ref, SubstitutionRef &subst,
+ Variance variance,
+ bool invariant_args) override;
+ void add_constrints_from_param (ParamType ¶m, Variance variance) override
+ {}
+ Variance contra (Variance variance) override;
+
+private:
+ GenericTyPerCrateCtx &ctx;
+ std::vector<Variance> variances;
+ std::vector<Region> regions;
+};
+
+} // namespace VarianceAnalysis
+
+} // namespace TyTy
+} // namespace Rust
+
+#endif // RUST_TYTY_VARIANCE_ANALYSIS_PRIVATE_H
new file mode 100644
@@ -0,0 +1,541 @@
+#include "rust-tyty-variance-analysis-private.h"
+#include "rust-hir-type-check.h"
+
+namespace Rust {
+namespace TyTy {
+
+BaseType *
+lookup_type (HirId ref)
+{
+ BaseType *ty = nullptr;
+ bool ok = Resolver::TypeCheckContext::get ()->lookup_type (ref, &ty);
+ rust_assert (ok);
+ return ty;
+}
+
+namespace VarianceAnalysis {
+
+CrateCtx::CrateCtx () : private_ctx (new GenericTyPerCrateCtx ()) {}
+
+// Must be here because of incomplete type.
+CrateCtx::~CrateCtx () = default;
+
+void
+CrateCtx::add_type_constraints (ADTType &type)
+{
+ private_ctx->process_type (type);
+}
+
+void
+CrateCtx::solve ()
+{
+ private_ctx->solve ();
+ private_ctx->debug_print_solutions ();
+}
+
+std::vector<Variance>
+CrateCtx::query_generic_variance (const ADTType &type)
+{
+ return private_ctx->query_generic_variance (type);
+}
+
+std::vector<Variance>
+CrateCtx::query_type_variances (BaseType *type)
+{
+ TyVisitorCtx ctx (*private_ctx);
+ return ctx.collect_variances (*type);
+}
+
+std::vector<Region>
+CrateCtx::query_type_regions (BaseType *type)
+{
+ TyVisitorCtx ctx (*private_ctx);
+ return ctx.collect_regions (*type);
+}
+
+Variance
+Variance::reverse () const
+{
+ switch (kind)
+ {
+ case BIVARIANT:
+ return bivariant ();
+ case COVARIANT:
+ return contravariant ();
+ case CONTRAVARIANT:
+ return covariant ();
+ case INVARIANT:
+ return invariant ();
+ }
+
+ rust_unreachable ();
+}
+
+Variance
+Variance::join (Variance lhs, Variance rhs)
+{
+ return {Kind (lhs.kind | rhs.kind)};
+}
+
+void
+Variance::join (Variance rhs)
+{
+ *this = join (*this, rhs);
+}
+
+Variance
+Variance::transform (Variance lhs, Variance rhs)
+{
+ switch (lhs.kind)
+ {
+ case BIVARIANT:
+ return bivariant ();
+ case COVARIANT:
+ return rhs;
+ case CONTRAVARIANT:
+ return rhs.reverse ();
+ case INVARIANT:
+ return invariant ();
+ }
+ rust_unreachable ();
+}
+
+std::string
+Variance::as_string () const
+{
+ switch (kind)
+ {
+ case BIVARIANT:
+ return "o";
+ case COVARIANT:
+ return "+";
+ case CONTRAVARIANT:
+ return "-";
+ case INVARIANT:
+ return "*";
+ }
+ rust_unreachable ();
+}
+
+void
+GenericTyPerCrateCtx::process_type (ADTType &type)
+{
+ GenericTyVisitorCtx (*this).process_type (type);
+}
+
+void
+GenericTyPerCrateCtx::solve ()
+{
+ rust_debug ("Variance analysis solving started:");
+
+ // Fix point iteration
+ bool changed = true;
+ while (changed)
+ {
+ changed = false;
+ for (auto constraint : constraints)
+ {
+ rust_debug ("\tapplying constraint: %s <= %s",
+ to_string (constraint.target_index).c_str (),
+ to_string (*constraint.term).c_str ());
+
+ auto old_solution = solutions[constraint.target_index];
+ auto new_solution
+ = Variance::join (old_solution, evaluate (constraint.term));
+
+ if (old_solution != new_solution)
+ {
+ rust_debug ("\t\tsolution changed: %s => %s",
+ old_solution.as_string ().c_str (),
+ new_solution.as_string ().c_str ());
+
+ changed = true;
+ solutions[constraint.target_index] = new_solution;
+ }
+ }
+ }
+
+ constraints.clear ();
+ constraints.shrink_to_fit ();
+}
+
+void
+GenericTyPerCrateCtx::debug_print_solutions ()
+{
+ rust_debug ("Variance analysis results:");
+
+ for (auto type : map_from_ty_orig_ref)
+ {
+ auto solution_index = type.second;
+ auto ref = type.first;
+
+ BaseType *ty = lookup_type (ref);
+
+ std::string result = "\t";
+
+ if (auto adt = ty->try_as<ADTType> ())
+ {
+ result += adt->get_identifier ();
+ result += "<";
+
+ size_t i = solution_index;
+ for (auto ®ion : adt->get_used_arguments ().get_regions ())
+ {
+ (void) region;
+ if (i > solution_index)
+ result += ", ";
+ result += solutions[i].as_string ();
+ i++;
+ }
+ for (auto ¶m : adt->get_substs ())
+ {
+ if (i > solution_index)
+ result += ", ";
+ result += param.get_generic_param ()
+ .get_type_representation ()
+ .as_string ();
+ result += "=";
+ result += solutions[i].as_string ();
+ i++;
+ }
+
+ result += ">";
+ }
+ else
+ {
+ rust_sorry_at (
+ ty->get_ref (),
+ "This is a compiler bug: Unhandled type in variance analysis");
+ }
+ rust_debug ("%s", result.c_str ());
+ }
+}
+
+tl::optional<SolutionIndex>
+GenericTyPerCrateCtx::lookup_type_index (HirId orig_ref)
+{
+ auto it = map_from_ty_orig_ref.find (orig_ref);
+ if (it != map_from_ty_orig_ref.end ())
+ {
+ return it->second;
+ }
+ return tl::nullopt;
+}
+
+void
+GenericTyVisitorCtx::process_type (ADTType &ty)
+{
+ rust_debug ("add_type_constraints: %s", ty.as_string ().c_str ());
+
+ first_lifetime = lookup_or_add_type (ty.get_orig_ref ());
+ first_type = first_lifetime + ty.get_used_arguments ().get_regions ().size ();
+
+ for (const auto ¶m : ty.get_substs ())
+ param_names.push_back (
+ param.get_generic_param ().get_type_representation ().as_string ());
+
+ for (const auto &variant : ty.get_variants ())
+ {
+ if (variant->get_variant_type () != VariantDef::NUM)
+ {
+ for (const auto &field : variant->get_fields ())
+ add_constraints_from_ty (field->get_field_type (),
+ Variance::covariant ());
+ }
+ }
+}
+
+std::string
+GenericTyPerCrateCtx::to_string (const Term &term) const
+{
+ switch (term.kind)
+ {
+ case Term::CONST:
+ return term.const_val.as_string ();
+ case Term::REF:
+ return "v(" + to_string (term.ref) + ")";
+ case Term::TRANSFORM:
+ return "(" + to_string (*term.transform.lhs) + " x "
+ + to_string (*term.transform.rhs) + ")";
+ }
+ rust_unreachable ();
+}
+
+std::string
+GenericTyPerCrateCtx::to_string (SolutionIndex index) const
+{
+ // Search all values in def_id_to_solution_index_start and find key for
+ // largest value smaller than index
+ std::pair<HirId, SolutionIndex> best = {0, 0};
+
+ for (const auto &ty_map : map_from_ty_orig_ref)
+ {
+ if (ty_map.second <= index && ty_map.first > best.first)
+ best = ty_map;
+ }
+ rust_assert (best.first != 0);
+
+ BaseType *ty = lookup_type (best.first);
+
+ std::string result = "";
+ if (auto adt = ty->try_as<ADTType> ())
+ {
+ result += (adt->get_identifier ());
+ }
+ else
+ {
+ result += ty->as_string ();
+ }
+
+ result += "[" + std::to_string (index - best.first) + "]";
+ return result;
+}
+
+Variance
+GenericTyPerCrateCtx::evaluate (Term *term)
+{
+ switch (term->kind)
+ {
+ case Term::CONST:
+ return term->const_val;
+ case Term::REF:
+ return solutions[term->ref];
+ case Term::TRANSFORM:
+ return Variance::transform (evaluate (term->transform.lhs),
+ evaluate (term->transform.rhs));
+ }
+ rust_unreachable ();
+}
+
+std::vector<Variance>
+GenericTyPerCrateCtx::query_generic_variance (const ADTType &type)
+{
+ auto solution_index = lookup_type_index (type.get_orig_ref ());
+ rust_assert (solution_index.has_value ());
+ auto num_lifetimes = type.get_num_lifetime_params ();
+ auto num_types = type.get_num_type_params ();
+
+ std::vector<Variance> result;
+ for (size_t i = 0; i < num_lifetimes + num_types; ++i)
+ {
+ result.push_back (solutions[solution_index.value () + i]);
+ }
+
+ return result;
+}
+
+SolutionIndex
+GenericTyVisitorCtx::lookup_or_add_type (HirId hir_id)
+{
+ BaseType *ty = lookup_type (hir_id);
+ auto index = ctx.lookup_type_index (hir_id);
+ if (index.has_value ())
+ {
+ return index.value ();
+ }
+
+ SubstitutionRef *subst = nullptr;
+ if (auto adt = ty->try_as<ADTType> ())
+ {
+ subst = adt;
+ }
+ else
+ {
+ rust_sorry_at (
+ ty->get_locus (),
+ "This is a compiler bug: Unhandled type in variance analysis");
+ }
+ rust_assert (subst != nullptr);
+
+ auto solution_index = ctx.solutions.size ();
+ ctx.map_from_ty_orig_ref.emplace (ty->get_orig_ref (), solution_index);
+
+ auto num_lifetime_param = subst->get_used_arguments ().get_regions ().size ();
+ auto num_type_param = subst->get_num_substitutions ();
+
+ for (size_t i = 0; i < num_lifetime_param + num_type_param; ++i)
+ ctx.solutions.emplace_back (Variance::bivariant ());
+
+ return solution_index;
+}
+
+void
+GenericTyVisitorCtx::add_constraints_from_ty (BaseType *type, Term variance)
+{
+ rust_debug ("\tadd_constraint_from_ty: %s with v=%s",
+ type->as_string ().c_str (), ctx.to_string (variance).c_str ());
+
+ Visitor visitor (*this, variance);
+ type->accept_vis (visitor);
+}
+
+void
+GenericTyVisitorCtx::add_constraint (SolutionIndex index, Term term)
+{
+ rust_debug ("\t\tadd_constraint: %s", ctx.to_string (term).c_str ());
+
+ if (term.kind == Term::CONST)
+ {
+ // Constant terms do not depend on other solutions, so we can
+ // immediately apply them.
+ ctx.solutions[index].join (term.const_val);
+ }
+ else
+ {
+ ctx.constraints.push_back ({index, new Term (term)});
+ }
+}
+
+void
+GenericTyVisitorCtx::add_constraints_from_region (const Region ®ion,
+ Term term)
+{
+ if (region.is_early_bound ())
+ {
+ add_constraint (first_lifetime + region.get_index (), term);
+ }
+}
+
+void
+GenericTyVisitorCtx::add_constraints_from_generic_args (HirId ref,
+ SubstitutionRef &subst,
+ Term variance,
+ bool invariant_args)
+{
+ SolutionIndex solution_index = lookup_or_add_type (ref);
+
+ size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
+ size_t num_types = subst.get_substs ().size ();
+
+ for (size_t i = 0; i < num_lifetimes + num_types; ++i)
+ {
+ // TODO: What about variance from other crates?
+ auto variance_i
+ = invariant_args
+ ? Term::make_transform (variance, Variance::invariant ())
+ : Term::make_transform (variance,
+ Term::make_ref (solution_index + i));
+
+ if (i < num_lifetimes)
+ {
+ auto region_i = i;
+ auto ®ion
+ = subst.get_substitution_arguments ().get_mut_regions ()[region_i];
+ add_constraints_from_region (region, variance_i);
+ }
+ else
+ {
+ auto type_i = i - num_lifetimes;
+ auto arg = subst.get_arg_at (type_i);
+ if (arg.has_value ())
+ {
+ add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
+ }
+ }
+ }
+}
+void
+GenericTyVisitorCtx::add_constrints_from_param (ParamType &type, Term variance)
+{
+ auto it
+ = std::find (param_names.begin (), param_names.end (), type.get_name ());
+ rust_assert (it != param_names.end ());
+
+ auto index = first_type + std::distance (param_names.begin (), it);
+
+ add_constraint (index, variance);
+}
+
+Term
+GenericTyVisitorCtx::contra (Term variance)
+{
+ return Term::make_transform (variance, Variance::contravariant ());
+}
+
+void
+TyVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
+{
+ Visitor visitor (*this, variance);
+ ty->accept_vis (visitor);
+}
+
+void
+TyVisitorCtx::add_constraints_from_region (const Region ®ion,
+ Variance variance)
+{
+ variances.push_back (variance);
+ regions.push_back (region);
+}
+
+void
+TyVisitorCtx::add_constraints_from_generic_args (HirId ref,
+ SubstitutionRef &subst,
+ Variance variance,
+ bool invariant_args)
+{
+ // Handle function
+ auto variances
+ = ctx.query_generic_variance (*lookup_type (ref)->as<ADTType> ());
+
+ size_t num_lifetimes = subst.get_used_arguments ().get_regions ().size ();
+ size_t num_types = subst.get_substs ().size ();
+
+ for (size_t i = 0; i < num_lifetimes + num_types; ++i)
+ {
+ // TODO: What about variance from other crates?
+ auto variance_i
+ = invariant_args
+ ? Variance::transform (variance, Variance::invariant ())
+ : Variance::transform (variance, variances[i]);
+
+ if (i < num_lifetimes)
+ {
+ auto region_i = i;
+ auto ®ion = subst.get_used_arguments ().get_regions ()[region_i];
+ add_constraints_from_region (region, variance_i);
+ }
+ else
+ {
+ auto type_i = i - num_lifetimes;
+ auto arg = subst.get_arg_at (type_i);
+ if (arg.has_value ())
+ {
+ add_constraints_from_ty (arg.value ().get_tyty (), variance_i);
+ }
+ }
+ }
+}
+
+Variance
+TyVisitorCtx::contra (Variance variance)
+{
+ return Variance::transform (variance, Variance::contravariant ());
+}
+
+Term
+Term::make_ref (SolutionIndex index)
+{
+ Term term;
+ term.kind = REF;
+ term.ref = index;
+ return term;
+}
+
+Term
+Term::make_transform (Term lhs, Term rhs)
+{
+ if (lhs.is_const () && rhs.is_const ())
+ {
+ return Variance::transform (lhs.const_val, rhs.const_val);
+ }
+
+ Term term;
+ term.kind = TRANSFORM;
+ term.transform.lhs = new Term (lhs);
+ term.transform.rhs = new Term (rhs);
+ return term;
+}
+
+} // namespace VarianceAnalysis
+} // namespace TyTy
+} // namespace Rust
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,114 @@
+#ifndef RUST_TYTY_VARIANCE_ANALYSIS_H
+#define RUST_TYTY_VARIANCE_ANALYSIS_H
+
+#include "rust-tyty.h"
+
+namespace Rust {
+namespace TyTy {
+namespace VarianceAnalysis {
+
+class Variance;
+class GenericTyPerCrateCtx;
+
+/** Per crate context for variance analysis. */
+class CrateCtx
+{
+public:
+ CrateCtx ();
+ ~CrateCtx ();
+
+ /** Add type to variance analysis context. */
+ void add_type_constraints (ADTType &type);
+
+ /** Solve all constraints and print debug output. */
+ void solve ();
+
+ /** Get variance of a type parameters. */
+ std::vector<Variance> query_generic_variance (const ADTType &type);
+
+ /** Get variance of a type body (members, fn parameters...). */
+ std::vector<Variance> query_type_variances (BaseType *type);
+
+ /** Get regions mentioned in a type. */
+ std::vector<Region> query_type_regions (BaseType *type);
+
+private:
+ std::unique_ptr<GenericTyPerCrateCtx> private_ctx;
+};
+
+/** Variance semilattice */
+class Variance
+{
+ enum Kind : uint8_t
+ {
+ BIVARIANT = 0, // 0b00
+ COVARIANT = 1, // 0b01
+ CONTRAVARIANT = 2, // 0b10
+ INVARIANT = 3, // 0b11
+ } kind;
+
+ static constexpr auto TOP = BIVARIANT;
+ static constexpr auto BOTTOM = INVARIANT;
+
+ constexpr Variance (Kind kind) : kind (kind) {}
+
+public:
+ constexpr Variance () : kind (TOP) {}
+
+ WARN_UNUSED_RESULT constexpr bool is_bivariant () const
+ {
+ return kind == BIVARIANT;
+ }
+ WARN_UNUSED_RESULT constexpr bool is_covariant () const
+ {
+ return kind == COVARIANT;
+ }
+ WARN_UNUSED_RESULT constexpr bool is_contravariant () const
+ {
+ return kind == CONTRAVARIANT;
+ }
+ WARN_UNUSED_RESULT constexpr bool is_invariant () const
+ {
+ return kind == INVARIANT;
+ }
+
+ static constexpr Variance bivariant () { return {BIVARIANT}; }
+ static constexpr Variance covariant () { return {COVARIANT}; }
+ static constexpr Variance contravariant () { return {CONTRAVARIANT}; }
+ static constexpr Variance invariant () { return {INVARIANT}; }
+
+ WARN_UNUSED_RESULT Variance reverse () const;
+ static WARN_UNUSED_RESULT Variance join (Variance lhs, Variance rhs);
+
+ void join (Variance rhs);
+
+ /**
+ * Variance composition function.
+ *
+ * For `A<X>` and `B<X>` and the composition `A<B<X>>` the variance of
+ * `v(A<B<X>>, X)` is defined as:
+ * ```
+ * v(A<B<X>>, X) = v(A<X>, X).transform(v(B<X>, X))
+ * ```
+ */
+ static WARN_UNUSED_RESULT Variance transform (Variance lhs, Variance rhs);
+
+ constexpr friend bool operator== (const Variance &lhs, const Variance &rhs)
+ {
+ return lhs.kind == rhs.kind;
+ }
+
+ constexpr friend bool operator!= (const Variance &lhs, const Variance &rhs)
+ {
+ return !(lhs == rhs);
+ }
+
+ WARN_UNUSED_RESULT std::string as_string () const;
+};
+
+} // namespace VarianceAnalysis
+
+} // namespace TyTy
+} // namespace Rust
+
+#endif // RUST_TYTY_VARIANCE_ANALYSIS_H