@@ -201,6 +201,13 @@ public: // Module internal API
std::vector<Variance> query_generic_variance (const ADTType &type);
+ std::vector<size_t> query_field_regions (const ADTType *parent,
+ size_t variant_index,
+ size_t field_index,
+ const FreeRegions &parent_regions);
+
+ std::vector<Region> query_type_regions (BaseType *base);
+
public: // Data used by visitors.
// This whole class is private, therfore members can be public.
@@ -296,6 +303,40 @@ private:
std::vector<Region> regions;
};
+/** Extracts regions of a field from regions of parent ADT. */
+class FieldVisitorCtx : public VarianceVisitorCtx<Variance>
+{
+public:
+ using Visitor = VisitorBase<Variance>;
+
+ std::vector<size_t> collect_regions (BaseType &ty);
+
+ FieldVisitorCtx (GenericTyPerCrateCtx &ctx, const SubstitutionRef &subst,
+ const FreeRegions &parent_regions)
+ : ctx (ctx), subst (subst), parent_regions (parent_regions)
+ {}
+
+ void add_constraints_from_ty (BaseType *ty, Variance variance) override;
+ void add_constraints_from_region (const Region ®ion,
+ Variance variance) override;
+ void add_constraints_from_generic_args (HirId ref, SubstitutionRef &subst,
+ Variance variance,
+ bool invariant_args) override{};
+ void add_constrints_from_param (ParamType ¶m, Variance variance) override;
+
+ Variance contra (Variance variance) override
+ {
+ return Variance::transform (variance, Variance::contravariant ());
+ }
+
+private:
+ GenericTyPerCrateCtx &ctx;
+ const SubstitutionRef &subst;
+ std::vector<size_t> regions;
+ FreeRegions parent_regions;
+ std::vector<size_t> type_param_ranges;
+};
+
} // namespace VarianceAnalysis
} // namespace TyTy
@@ -49,8 +49,16 @@ CrateCtx::query_type_variances (BaseType *type)
std::vector<Region>
CrateCtx::query_type_regions (BaseType *type)
{
- TyVisitorCtx ctx (*private_ctx);
- return ctx.collect_regions (*type);
+ return private_ctx->query_type_regions (type);
+}
+
+std::vector<size_t>
+CrateCtx::query_field_regions (const ADTType *parent, size_t variant_index,
+ size_t field_index,
+ const FreeRegions &parent_regions)
+{
+ return private_ctx->query_field_regions (parent, variant_index, field_index,
+ parent_regions);
}
Variance
@@ -324,6 +332,29 @@ GenericTyPerCrateCtx::query_generic_variance (const ADTType &type)
return result;
}
+std::vector<size_t>
+GenericTyPerCrateCtx::query_field_regions (const ADTType *parent,
+ size_t variant_index,
+ size_t field_index,
+ const FreeRegions &parent_regions)
+{
+ auto orig = lookup_type (parent->get_orig_ref ());
+ FieldVisitorCtx ctx (*this, *parent->as<const SubstitutionRef> (),
+ parent_regions);
+ return ctx.collect_regions (*orig->as<const ADTType> ()
+ ->get_variants ()
+ .at (variant_index)
+ ->get_fields ()
+ .at (field_index)
+ ->get_field_type ());
+}
+std::vector<Region>
+GenericTyPerCrateCtx::query_type_regions (BaseType *type)
+{
+ TyVisitorCtx ctx (*this);
+ return ctx.collect_regions (*type);
+}
+
SolutionIndex
GenericTyVisitorCtx::lookup_or_add_type (HirId hir_id)
{
@@ -506,6 +537,58 @@ TyVisitorCtx::add_constraints_from_generic_args (HirId ref,
}
}
+std::vector<size_t>
+FieldVisitorCtx::collect_regions (BaseType &ty)
+{
+ // Segment the regions into ranges for each type parameter. Type parameter
+ // at index i contains regions from type_param_ranges[i] to
+ // type_param_ranges[i+1] (exclusive).;
+ type_param_ranges.push_back (subst.get_num_lifetime_params ());
+
+ for (size_t i = 0; i < subst.get_num_type_params (); i++)
+ {
+ auto arg = subst.get_arg_at (i);
+ rust_assert (arg.has_value ());
+ type_param_ranges.push_back (
+ ctx.query_type_regions (arg.value ().get_tyty ()).size ());
+ }
+
+ add_constraints_from_ty (&ty, Variance::covariant ());
+ return regions;
+}
+
+void
+FieldVisitorCtx::add_constraints_from_ty (BaseType *ty, Variance variance)
+{
+ Visitor visitor (*this, variance);
+ ty->accept_vis (visitor);
+}
+
+void
+FieldVisitorCtx::add_constraints_from_region (const Region ®ion,
+ Variance variance)
+{
+ if (region.is_early_bound ())
+ {
+ regions.push_back (parent_regions[region.get_index ()]);
+ }
+ else if (region.is_late_bound ())
+ {
+ rust_debug ("Ignoring late bound region");
+ }
+}
+
+void
+FieldVisitorCtx::add_constrints_from_param (ParamType ¶m, Variance variance)
+{
+ size_t param_i = subst.get_used_arguments ().find_symbol (param).value ();
+ for (size_t i = type_param_ranges[param_i];
+ i < type_param_ranges[param_i + 1]; i++)
+ {
+ regions.push_back (parent_regions[i]);
+ }
+}
+
Variance
TyVisitorCtx::contra (Variance variance)
{
@@ -3,6 +3,8 @@
#include "rust-tyty.h"
+#include <rust-bir-free-region.h>
+
namespace Rust {
namespace TyTy {
namespace VarianceAnalysis {
@@ -31,11 +33,19 @@ public:
/** Get regions mentioned in a type. */
std::vector<Region> query_type_regions (BaseType *type);
+ std::vector<size_t> query_field_regions (const ADTType *parent,
+ size_t variant_index,
+ size_t field_index,
+ const FreeRegions &parent_regions);
private:
std::unique_ptr<GenericTyPerCrateCtx> private_ctx;
};
+std::vector<size_t>
+query_field_regions (const ADTType *parent, size_t variant_index,
+ size_t field_index, const FreeRegions &parent_regions);
+
/** Variance semilattice */
class Variance
{
@@ -55,22 +65,10 @@ class Variance
public:
constexpr Variance () : kind (TOP) {}
- WARN_UNUSED_RESULT constexpr bool is_bivariant () const
- {
- return kind == BIVARIANT;
- }
- WARN_UNUSED_RESULT constexpr bool is_covariant () const
- {
- return kind == COVARIANT;
- }
- WARN_UNUSED_RESULT constexpr bool is_contravariant () const
- {
- return kind == CONTRAVARIANT;
- }
- WARN_UNUSED_RESULT constexpr bool is_invariant () const
- {
- return kind == INVARIANT;
- }
+ constexpr bool is_bivariant () const { return kind == BIVARIANT; }
+ constexpr bool is_covariant () const { return kind == COVARIANT; }
+ constexpr bool is_contravariant () const { return kind == CONTRAVARIANT; }
+ constexpr bool is_invariant () const { return kind == INVARIANT; }
static constexpr Variance bivariant () { return {BIVARIANT}; }
static constexpr Variance covariant () { return {COVARIANT}; }
@@ -97,7 +95,6 @@ public:
{
return lhs.kind == rhs.kind;
}
-
constexpr friend bool operator!= (const Variance &lhs, const Variance &rhs)
{
return !(lhs == rhs);