From patchwork Tue Dec 6 10:14:01 2022 Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit X-Patchwork-Submitter: Arthur Cohen X-Patchwork-Id: 61521 Return-Path: X-Original-To: patchwork@sourceware.org Delivered-To: patchwork@sourceware.org Received: from server2.sourceware.org (localhost [IPv6:::1]) by sourceware.org (Postfix) with ESMTP id EF7553836BB7 for ; Tue, 6 Dec 2022 10:21:38 +0000 (GMT) X-Original-To: gcc-patches@gcc.gnu.org Delivered-To: gcc-patches@gcc.gnu.org Received: from mail-wm1-x332.google.com (mail-wm1-x332.google.com [IPv6:2a00:1450:4864:20::332]) by sourceware.org (Postfix) with ESMTPS id 31D77395B834 for ; Tue, 6 Dec 2022 10:12:34 +0000 (GMT) DMARC-Filter: OpenDMARC Filter v1.4.1 sourceware.org 31D77395B834 Authentication-Results: sourceware.org; dmarc=none (p=none dis=none) header.from=embecosm.com Authentication-Results: sourceware.org; spf=pass smtp.mailfrom=embecosm.com Received: by mail-wm1-x332.google.com with SMTP id n7so10787015wms.3 for ; Tue, 06 Dec 2022 02:12:34 -0800 (PST) DKIM-Signature: v=1; a=rsa-sha256; c=relaxed/relaxed; d=embecosm.com; s=google; h=content-transfer-encoding:mime-version:reply-to:references :in-reply-to:message-id:date:subject:cc:to:from:from:to:cc:subject :date:message-id:reply-to; bh=MUXjs+CR4b4IFwySUiXWv8Ms6ayeQ2ct2uMn1fvWRAM=; b=Lu69o3o6XWC+kxqWU72ecjrO4hr518xQvu5pwxCLbFCtMRbicenA7BoESkBf5o/ic7 u7zDCB85YWyMhCDFwinyxtC3DXHeMgT0hCErp4VVRF5IjXErpuxCPGoRGXrGtMKX4WLL zQLQEjlwUy35GnHDFiGmxwXPn0KdJvu6XcHFvdwSx23t7esdzvUQRvhgyRmeHmmVtLi9 kxLo+BGxiUSL9hcXFAplscIDUMuBQ44CnwTCaLyXsBXz2T7BjLW0UvKa8Y8tr03M1ZxF 5hifzS+fSDmuzpI55L/ESI2eiDJBoH/iJc9ZEPOtCNbfF2yzGF99D8RWNGjJ+WrjUh/i XEgg== X-Google-DKIM-Signature: v=1; a=rsa-sha256; c=relaxed/relaxed; d=1e100.net; s=20210112; h=content-transfer-encoding:mime-version:reply-to:references :in-reply-to:message-id:date:subject:cc:to:from:x-gm-message-state :from:to:cc:subject:date:message-id:reply-to; bh=MUXjs+CR4b4IFwySUiXWv8Ms6ayeQ2ct2uMn1fvWRAM=; b=GX6ldqogqIX5oxosaJNST5D5y202aDpYwMOtMm1H5Nl+evaInI3f2dU0Ke4RriZzow ORAo8pteFSOGZO/6C0v0uJWp0WHKAr9kRQPNAulEYAQi1Q8rrANbWdh/BP474yqQU/Na HICFggbe5EyjDL/VPOVtX1iq5yjBYVrq0errXRSefuOXqH78rYZBo1X2M1bDVo/u03EY Mz847HPFlSKLDN5HZmFAnv20Nc6JaxPVAxyHvHGkHCSVmeYk9X3kIXBarIebx2dVfNgd gFs1BDBtNI9NR5uWlcUbCsvvKw4WFAg8H+jJklstfxTgdzJMNGIZ5J9qaKwksT60wSHI fCiA== X-Gm-Message-State: ANoB5pmmRNd54MKfc4nvZ5AkiMUHngWYhQb03Qmp38+yFL79ycfa/ek/ AJVSYRhIfGiaxvjlEMMyiillpFKgBDBIwNs0DQ== X-Google-Smtp-Source: AA0mqf7QL9UQbFuKnrPW9OkRE0J1CWHOajdAX/aDkBtAh9fn5jXDYNUL6v4eYSB2eCZav4hdnksQjw== X-Received: by 2002:a05:600c:3c96:b0:3cf:a457:2d89 with SMTP id bg22-20020a05600c3c9600b003cfa4572d89mr64426497wmb.20.1670321551191; Tue, 06 Dec 2022 02:12:31 -0800 (PST) Received: from platypus.lan ([2001:861:5e4c:3bb0:6424:328a:1734:3249]) by smtp.googlemail.com with ESMTPSA id r10-20020a05600c458a00b003cfd4a50d5asm27052699wmo.34.2022.12.06.02.12.30 (version=TLS1_3 cipher=TLS_AES_256_GCM_SHA384 bits=256/256); Tue, 06 Dec 2022 02:12:30 -0800 (PST) From: arthur.cohen@embecosm.com To: gcc-patches@gcc.gnu.org Cc: gcc-rust@gcc.gnu.org, Philip Herron Subject: [PATCH Rust front-end v4 29/46] gccrs: Add remaining type system transformations Date: Tue, 6 Dec 2022 11:14:01 +0100 Message-Id: <20221206101417.778807-30-arthur.cohen@embecosm.com> X-Mailer: git-send-email 2.38.1 In-Reply-To: <20221206101417.778807-1-arthur.cohen@embecosm.com> References: <20221206101417.778807-1-arthur.cohen@embecosm.com> MIME-Version: 1.0 X-Spam-Status: No, score=-18.2 required=5.0 tests=BAYES_00, DKIM_SIGNED, DKIM_VALID, DKIM_VALID_AU, DKIM_VALID_EF, GIT_PATCH_0, KAM_LOTSOFHASH, KAM_SHORT, RCVD_IN_DNSWL_NONE, SPF_HELO_NONE, SPF_PASS, TXREP autolearn=unavailable autolearn_force=no version=3.4.6 X-Spam-Checker-Version: SpamAssassin 3.4.6 (2021-04-09) on server2.sourceware.org X-BeenThere: gcc-patches@gcc.gnu.org X-Mailman-Version: 2.1.29 Precedence: list List-Id: Gcc-patches mailing list List-Unsubscribe: , List-Archive: List-Post: List-Help: List-Subscribe: , Reply-To: arthur.cohen@embecosm.com Errors-To: gcc-patches-bounces+patchwork=sourceware.org@gcc.gnu.org Sender: "Gcc-patches" From: Philip Herron This patch implements multiple transformation performed on the HIR during type-resolution such as type coercion, casts, auto-dereferencement. --- gcc/rust/typecheck/rust-autoderef.cc | 398 +++++ gcc/rust/typecheck/rust-autoderef.h | 178 ++ gcc/rust/typecheck/rust-casts.cc | 292 +++ gcc/rust/typecheck/rust-casts.h | 53 + gcc/rust/typecheck/rust-coercion.cc | 357 ++++ gcc/rust/typecheck/rust-coercion.h | 93 + gcc/rust/typecheck/rust-hir-dot-operator.cc | 263 +++ gcc/rust/typecheck/rust-hir-dot-operator.h | 81 + .../rust-hir-inherent-impl-overlap.h | 186 ++ gcc/rust/typecheck/rust-hir-path-probe.h | 540 ++++++ gcc/rust/typecheck/rust-hir-trait-ref.h | 472 +++++ gcc/rust/typecheck/rust-hir-type-bounds.h | 77 + .../typecheck/rust-substitution-mapper.cc | 77 + gcc/rust/typecheck/rust-substitution-mapper.h | 394 ++++ gcc/rust/typecheck/rust-tycheck-dump.h | 239 +++ gcc/rust/typecheck/rust-tyctx.cc | 155 ++ gcc/rust/typecheck/rust-tyty-bounds.cc | 462 +++++ gcc/rust/typecheck/rust-tyty-call.cc | 263 +++ gcc/rust/typecheck/rust-tyty-call.h | 147 ++ gcc/rust/typecheck/rust-tyty-cmp.h | 1554 ++++++++++++++++ gcc/rust/typecheck/rust-tyty-rules.h | 1584 +++++++++++++++++ 21 files changed, 7865 insertions(+) create mode 100644 gcc/rust/typecheck/rust-autoderef.cc create mode 100644 gcc/rust/typecheck/rust-autoderef.h create mode 100644 gcc/rust/typecheck/rust-casts.cc create mode 100644 gcc/rust/typecheck/rust-casts.h create mode 100644 gcc/rust/typecheck/rust-coercion.cc create mode 100644 gcc/rust/typecheck/rust-coercion.h create mode 100644 gcc/rust/typecheck/rust-hir-dot-operator.cc create mode 100644 gcc/rust/typecheck/rust-hir-dot-operator.h create mode 100644 gcc/rust/typecheck/rust-hir-inherent-impl-overlap.h create mode 100644 gcc/rust/typecheck/rust-hir-path-probe.h create mode 100644 gcc/rust/typecheck/rust-hir-trait-ref.h create mode 100644 gcc/rust/typecheck/rust-hir-type-bounds.h create mode 100644 gcc/rust/typecheck/rust-substitution-mapper.cc create mode 100644 gcc/rust/typecheck/rust-substitution-mapper.h create mode 100644 gcc/rust/typecheck/rust-tycheck-dump.h create mode 100644 gcc/rust/typecheck/rust-tyctx.cc create mode 100644 gcc/rust/typecheck/rust-tyty-bounds.cc create mode 100644 gcc/rust/typecheck/rust-tyty-call.cc create mode 100644 gcc/rust/typecheck/rust-tyty-call.h create mode 100644 gcc/rust/typecheck/rust-tyty-cmp.h create mode 100644 gcc/rust/typecheck/rust-tyty-rules.h diff --git a/gcc/rust/typecheck/rust-autoderef.cc b/gcc/rust/typecheck/rust-autoderef.cc new file mode 100644 index 00000000000..423f8e4709b --- /dev/null +++ b/gcc/rust/typecheck/rust-autoderef.cc @@ -0,0 +1,398 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#include "rust-autoderef.h" +#include "rust-hir-path-probe.h" +#include "rust-hir-dot-operator.h" +#include "rust-hir-trait-resolve.h" + +namespace Rust { +namespace Resolver { + +static bool +resolve_operator_overload_fn ( + Analysis::RustLangItem::ItemType lang_item_type, const TyTy::BaseType *ty, + TyTy::FnType **resolved_fn, HIR::ImplItem **impl_item, + Adjustment::AdjustmentType *requires_ref_adjustment); + +TyTy::BaseType * +Adjuster::adjust_type (const std::vector &adjustments) +{ + if (adjustments.size () == 0) + return base->clone (); + + return adjustments.back ().get_expected ()->clone (); +} + +Adjustment +Adjuster::try_deref_type (const TyTy::BaseType *ty, + Analysis::RustLangItem::ItemType deref_lang_item) +{ + HIR::ImplItem *impl_item = nullptr; + TyTy::FnType *fn = nullptr; + Adjustment::AdjustmentType requires_ref_adjustment + = Adjustment::AdjustmentType::ERROR; + bool operator_overloaded + = resolve_operator_overload_fn (deref_lang_item, ty, &fn, &impl_item, + &requires_ref_adjustment); + if (!operator_overloaded) + { + return Adjustment::get_error (); + } + + auto resolved_base = fn->get_return_type ()->clone (); + bool is_valid_type = resolved_base->get_kind () == TyTy::TypeKind::REF; + if (!is_valid_type) + return Adjustment::get_error (); + + TyTy::ReferenceType *ref_base + = static_cast (resolved_base); + + Adjustment::AdjustmentType adjustment_type + = Adjustment::AdjustmentType::ERROR; + switch (deref_lang_item) + { + case Analysis::RustLangItem::ItemType::DEREF: + adjustment_type = Adjustment::AdjustmentType::DEREF; + break; + + case Analysis::RustLangItem::ItemType::DEREF_MUT: + adjustment_type = Adjustment::AdjustmentType::DEREF_MUT; + break; + + default: + break; + } + + return Adjustment::get_op_overload_deref_adjustment (adjustment_type, ty, + ref_base, fn, impl_item, + requires_ref_adjustment); +} + +Adjustment +Adjuster::try_raw_deref_type (const TyTy::BaseType *ty) +{ + bool is_valid_type = ty->get_kind () == TyTy::TypeKind::REF; + if (!is_valid_type) + return Adjustment::get_error (); + + const TyTy::ReferenceType *ref_base + = static_cast (ty); + auto infered = ref_base->get_base ()->clone (); + + return Adjustment (Adjustment::AdjustmentType::INDIRECTION, ty, infered); +} + +Adjustment +Adjuster::try_unsize_type (const TyTy::BaseType *ty) +{ + bool is_valid_type = ty->get_kind () == TyTy::TypeKind::ARRAY; + if (!is_valid_type) + return Adjustment::get_error (); + + auto mappings = Analysis::Mappings::get (); + auto context = TypeCheckContext::get (); + + const auto ref_base = static_cast (ty); + auto slice_elem = ref_base->get_element_type (); + + auto slice + = new TyTy::SliceType (mappings->get_next_hir_id (), ty->get_ident ().locus, + TyTy::TyVar (slice_elem->get_ref ())); + context->insert_implicit_type (slice); + + return Adjustment (Adjustment::AdjustmentType::UNSIZE, ty, slice); +} + +static bool +resolve_operator_overload_fn ( + Analysis::RustLangItem::ItemType lang_item_type, const TyTy::BaseType *ty, + TyTy::FnType **resolved_fn, HIR::ImplItem **impl_item, + Adjustment::AdjustmentType *requires_ref_adjustment) +{ + auto context = TypeCheckContext::get (); + auto mappings = Analysis::Mappings::get (); + + // look up lang item for arithmetic type + std::string associated_item_name + = Analysis::RustLangItem::ToString (lang_item_type); + DefId respective_lang_item_id = UNKNOWN_DEFID; + bool lang_item_defined + = mappings->lookup_lang_item (lang_item_type, &respective_lang_item_id); + + if (!lang_item_defined) + return false; + + auto segment = HIR::PathIdentSegment (associated_item_name); + auto candidate + = MethodResolver::Probe (ty, HIR::PathIdentSegment (associated_item_name), + true); + + bool have_implementation_for_lang_item = !candidate.is_error (); + if (!have_implementation_for_lang_item) + return false; + + // Get the adjusted self + Adjuster adj (ty); + TyTy::BaseType *adjusted_self = adj.adjust_type (candidate.adjustments); + + // is this the case we are recursive + // handle the case where we are within the impl block for this + // lang_item otherwise we end up with a recursive operator overload + // such as the i32 operator overload trait + TypeCheckContextItem &fn_context = context->peek_context (); + if (fn_context.get_type () == TypeCheckContextItem::ItemType::IMPL_ITEM) + { + auto &impl_item = fn_context.get_impl_item (); + HIR::ImplBlock *parent = impl_item.first; + HIR::Function *fn = impl_item.second; + + if (parent->has_trait_ref () + && fn->get_function_name ().compare (associated_item_name) == 0) + { + TraitReference *trait_reference + = TraitResolver::Lookup (*parent->get_trait_ref ().get ()); + if (!trait_reference->is_error ()) + { + TyTy::BaseType *lookup = nullptr; + bool ok = context->lookup_type (fn->get_mappings ().get_hirid (), + &lookup); + rust_assert (ok); + rust_assert (lookup->get_kind () == TyTy::TypeKind::FNDEF); + + TyTy::FnType *fntype = static_cast (lookup); + rust_assert (fntype->is_method ()); + + bool is_lang_item_impl + = trait_reference->get_mappings ().get_defid () + == respective_lang_item_id; + bool self_is_lang_item_self + = fntype->get_self_type ()->is_equal (*adjusted_self); + bool recursive_operator_overload + = is_lang_item_impl && self_is_lang_item_self; + + if (recursive_operator_overload) + return false; + } + } + } + + TyTy::BaseType *lookup_tyty = candidate.candidate.ty; + + // rust only support impl item deref operator overloading ie you must have an + // impl block for it + rust_assert (candidate.candidate.type + == PathProbeCandidate::CandidateType::IMPL_FUNC); + *impl_item = candidate.candidate.item.impl.impl_item; + + rust_assert (lookup_tyty->get_kind () == TyTy::TypeKind::FNDEF); + TyTy::BaseType *lookup = lookup_tyty; + TyTy::FnType *fn = static_cast (lookup); + rust_assert (fn->is_method ()); + + if (fn->needs_substitution ()) + { + if (ty->get_kind () == TyTy::TypeKind::ADT) + { + const TyTy::ADTType *adt = static_cast (ty); + + auto s = fn->get_self_type ()->get_root (); + rust_assert (s->can_eq (adt, false)); + rust_assert (s->get_kind () == TyTy::TypeKind::ADT); + const TyTy::ADTType *self_adt + = static_cast (s); + + // we need to grab the Self substitutions as the inherit type + // parameters for this + if (self_adt->needs_substitution ()) + { + rust_assert (adt->was_substituted ()); + + TyTy::SubstitutionArgumentMappings used_args_in_prev_segment + = GetUsedSubstArgs::From (adt); + + TyTy::SubstitutionArgumentMappings inherit_type_args + = self_adt->solve_mappings_from_receiver_for_self ( + used_args_in_prev_segment); + + // there may or may not be inherited type arguments + if (!inherit_type_args.is_error ()) + { + // need to apply the inherited type arguments to the + // function + lookup = fn->handle_substitions (inherit_type_args); + } + } + } + else + { + rust_assert (candidate.adjustments.size () < 2); + + // lets infer the params for this we could probably fix this up by + // actually just performing a substitution of a single param but this + // seems more generic i think. + // + // this is the case where we had say Foo<&Bar>> and we have derefed to + // the &Bar and we are trying to match a method self of Bar which + // requires another deref which is matched to the deref trait impl of + // &&T so this requires another reference and deref call + + lookup = fn->infer_substitions (Location ()); + rust_assert (lookup->get_kind () == TyTy::TypeKind::FNDEF); + fn = static_cast (lookup); + fn->get_self_type ()->unify (adjusted_self); + lookup = fn; + } + } + + if (candidate.adjustments.size () > 0) + *requires_ref_adjustment = candidate.adjustments.at (0).get_type (); + + *resolved_fn = static_cast (lookup); + + return true; +} + +AutoderefCycle::AutoderefCycle (bool autoderef_flag) + : autoderef_flag (autoderef_flag) +{} + +AutoderefCycle::~AutoderefCycle () {} + +void +AutoderefCycle::try_hook (const TyTy::BaseType &) +{} + +bool +AutoderefCycle::cycle (const TyTy::BaseType *receiver) +{ + const TyTy::BaseType *r = receiver; + while (true) + { + if (try_autoderefed (r)) + return true; + + // 4. deref to to 1, if cannot deref then quit + if (autoderef_flag) + return false; + + // try unsize + Adjustment unsize = Adjuster::try_unsize_type (r); + if (!unsize.is_error ()) + { + adjustments.push_back (unsize); + auto unsize_r = unsize.get_expected (); + + if (try_autoderefed (unsize_r)) + return true; + + adjustments.pop_back (); + } + + Adjustment deref + = Adjuster::try_deref_type (r, Analysis::RustLangItem::ItemType::DEREF); + if (!deref.is_error ()) + { + auto deref_r = deref.get_expected (); + adjustments.push_back (deref); + + if (try_autoderefed (deref_r)) + return true; + + adjustments.pop_back (); + } + + Adjustment deref_mut = Adjuster::try_deref_type ( + r, Analysis::RustLangItem::ItemType::DEREF_MUT); + if (!deref_mut.is_error ()) + { + auto deref_r = deref_mut.get_expected (); + adjustments.push_back (deref_mut); + + if (try_autoderefed (deref_r)) + return true; + + adjustments.pop_back (); + } + + if (!deref_mut.is_error ()) + { + auto deref_r = deref_mut.get_expected (); + adjustments.push_back (deref_mut); + Adjustment raw_deref = Adjuster::try_raw_deref_type (deref_r); + adjustments.push_back (raw_deref); + deref_r = raw_deref.get_expected (); + + if (try_autoderefed (deref_r)) + return true; + + adjustments.pop_back (); + adjustments.pop_back (); + } + + if (!deref.is_error ()) + { + r = deref.get_expected (); + adjustments.push_back (deref); + } + Adjustment raw_deref = Adjuster::try_raw_deref_type (r); + if (raw_deref.is_error ()) + return false; + + r = raw_deref.get_expected (); + adjustments.push_back (raw_deref); + } + return false; +} + +bool +AutoderefCycle::try_autoderefed (const TyTy::BaseType *r) +{ + try_hook (*r); + + // 1. try raw + if (select (*r)) + return true; + + // 2. try ref + TyTy::ReferenceType *r1 + = new TyTy::ReferenceType (r->get_ref (), TyTy::TyVar (r->get_ref ()), + Mutability::Imm); + adjustments.push_back ( + Adjustment (Adjustment::AdjustmentType::IMM_REF, r, r1)); + if (select (*r1)) + return true; + + adjustments.pop_back (); + + // 3. try mut ref + TyTy::ReferenceType *r2 + = new TyTy::ReferenceType (r->get_ref (), TyTy::TyVar (r->get_ref ()), + Mutability::Mut); + adjustments.push_back ( + Adjustment (Adjustment::AdjustmentType::MUT_REF, r, r2)); + if (select (*r2)) + return true; + + adjustments.pop_back (); + + return false; +} + +} // namespace Resolver +} // namespace Rust diff --git a/gcc/rust/typecheck/rust-autoderef.h b/gcc/rust/typecheck/rust-autoderef.h new file mode 100644 index 00000000000..2f8d64b97e6 --- /dev/null +++ b/gcc/rust/typecheck/rust-autoderef.h @@ -0,0 +1,178 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_AUTODEREF +#define RUST_AUTODEREF + +#include "rust-tyty.h" + +namespace Rust { +namespace Resolver { + +class Adjustment +{ +public: + enum AdjustmentType + { + ERROR, + + IMM_REF, + MUT_REF, + DEREF, + DEREF_MUT, + INDIRECTION, + UNSIZE, + }; + + // ctor for all adjustments except derefs + Adjustment (AdjustmentType type, const TyTy::BaseType *actual, + const TyTy::BaseType *expected) + : Adjustment (type, actual, expected, nullptr, nullptr, + AdjustmentType::ERROR) + {} + + static Adjustment get_op_overload_deref_adjustment ( + AdjustmentType type, const TyTy::BaseType *actual, + const TyTy::BaseType *expected, TyTy::FnType *fn, HIR::ImplItem *deref_item, + Adjustment::AdjustmentType requires_ref_adjustment) + { + rust_assert (type == DEREF || type == DEREF_MUT); + return Adjustment (type, actual, expected, fn, deref_item, + requires_ref_adjustment); + } + + AdjustmentType get_type () const { return type; } + + const TyTy::BaseType *get_actual () const { return actual; } + const TyTy::BaseType *get_expected () const { return expected; } + + std::string as_string () const + { + return Adjustment::type_string (get_type ()) + "->" + + get_expected ()->debug_str (); + } + + static std::string type_string (AdjustmentType type) + { + switch (type) + { + case AdjustmentType::ERROR: + return "ERROR"; + case AdjustmentType::IMM_REF: + return "IMM_REF"; + case AdjustmentType::MUT_REF: + return "MUT_REF"; + case AdjustmentType::DEREF: + return "DEREF"; + case AdjustmentType::DEREF_MUT: + return "DEREF_MUT"; + case AdjustmentType::INDIRECTION: + return "INDIRECTION"; + case AdjustmentType::UNSIZE: + return "UNSIZE"; + } + gcc_unreachable (); + return ""; + } + + static Adjustment get_error () { return Adjustment{ERROR, nullptr, nullptr}; } + + bool is_error () const { return type == ERROR; } + + bool is_deref_adjustment () const { return type == DEREF; } + + bool is_deref_mut_adjustment () const { return type == DEREF_MUT; } + + bool has_operator_overload () const { return deref_operator_fn != nullptr; } + + TyTy::FnType *get_deref_operator_fn () const { return deref_operator_fn; } + + AdjustmentType get_deref_adjustment_type () const + { + return requires_ref_adjustment; + } + + HIR::ImplItem *get_deref_hir_item () const { return deref_item; } + +private: + Adjustment (AdjustmentType type, const TyTy::BaseType *actual, + const TyTy::BaseType *expected, TyTy::FnType *deref_operator_fn, + HIR::ImplItem *deref_item, + Adjustment::AdjustmentType requires_ref_adjustment) + : type (type), actual (actual), expected (expected), + deref_operator_fn (deref_operator_fn), deref_item (deref_item), + requires_ref_adjustment (requires_ref_adjustment) + {} + + AdjustmentType type; + const TyTy::BaseType *actual; + const TyTy::BaseType *expected; + + // - only used for deref operator_overloads + // + // the fn that we are calling + TyTy::FnType *deref_operator_fn; + HIR::ImplItem *deref_item; + // operator overloads can requre a reference + Adjustment::AdjustmentType requires_ref_adjustment; +}; + +class Adjuster +{ +public: + Adjuster (const TyTy::BaseType *ty) : base (ty) {} + + TyTy::BaseType *adjust_type (const std::vector &adjustments); + + static Adjustment + try_deref_type (const TyTy::BaseType *ty, + Analysis::RustLangItem::ItemType deref_lang_item); + + static Adjustment try_raw_deref_type (const TyTy::BaseType *ty); + + static Adjustment try_unsize_type (const TyTy::BaseType *ty); + +private: + const TyTy::BaseType *base; +}; + +class AutoderefCycle +{ +protected: + AutoderefCycle (bool autoderef_flag); + + virtual ~AutoderefCycle (); + + virtual bool select (const TyTy::BaseType &autoderefed) = 0; + + // optional: this is a chance to hook in to grab predicate items on the raw + // type + virtual void try_hook (const TyTy::BaseType &); + + virtual bool cycle (const TyTy::BaseType *receiver); + + bool try_autoderefed (const TyTy::BaseType *r); + + bool autoderef_flag; + std::vector adjustments; +}; + +} // namespace Resolver +} // namespace Rust + +#endif // RUST_AUTODEREF diff --git a/gcc/rust/typecheck/rust-casts.cc b/gcc/rust/typecheck/rust-casts.cc new file mode 100644 index 00000000000..61004dfabc3 --- /dev/null +++ b/gcc/rust/typecheck/rust-casts.cc @@ -0,0 +1,292 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#include "rust-casts.h" + +namespace Rust { +namespace Resolver { + +TypeCastRules::TypeCastRules (Location locus, TyTy::TyWithLocation from, + TyTy::TyWithLocation to) + : locus (locus), from (from), to (to) +{} + +TypeCoercionRules::CoercionResult +TypeCastRules::resolve (Location locus, TyTy::TyWithLocation from, + TyTy::TyWithLocation to) +{ + TypeCastRules cast_rules (locus, from, to); + return cast_rules.check (); +} + +TypeCoercionRules::CoercionResult +TypeCastRules::check () +{ + // https://github.com/rust-lang/rust/blob/7eac88abb2e57e752f3302f02be5f3ce3d7adfb4/compiler/rustc_typeck/src/check/cast.rs#L565-L582 + auto possible_coercion + = TypeCoercionRules::TryCoerce (from.get_ty (), to.get_ty (), locus); + if (!possible_coercion.is_error ()) + return possible_coercion; + + // try the simple cast rules + auto simple_cast = cast_rules (); + if (!simple_cast.is_error ()) + return simple_cast; + + // failed to cast + emit_cast_error (); + return TypeCoercionRules::CoercionResult::get_error (); +} + +TypeCoercionRules::CoercionResult +TypeCastRules::cast_rules () +{ + // https://github.com/rust-lang/rust/blob/7eac88abb2e57e752f3302f02be5f3ce3d7adfb4/compiler/rustc_typeck/src/check/cast.rs#L596 + // https://github.com/rust-lang/rust/blob/7eac88abb2e57e752f3302f02be5f3ce3d7adfb4/compiler/rustc_typeck/src/check/cast.rs#L654 + + rust_debug ("cast_rules from={%s} to={%s}", + from.get_ty ()->debug_str ().c_str (), + to.get_ty ()->debug_str ().c_str ()); + + switch (from.get_ty ()->get_kind ()) + { + case TyTy::TypeKind::INFER: { + TyTy::InferType *from_infer + = static_cast (from.get_ty ()); + switch (from_infer->get_infer_kind ()) + { + case TyTy::InferType::InferTypeKind::GENERAL: + return TypeCoercionRules::CoercionResult{{}, + to.get_ty ()->clone ()}; + + case TyTy::InferType::InferTypeKind::INTEGRAL: + switch (to.get_ty ()->get_kind ()) + { + case TyTy::TypeKind::CHAR: + case TyTy::TypeKind::BOOL: + case TyTy::TypeKind::USIZE: + case TyTy::TypeKind::ISIZE: + case TyTy::TypeKind::UINT: + case TyTy::TypeKind::INT: + case TyTy::TypeKind::POINTER: + return TypeCoercionRules::CoercionResult{ + {}, to.get_ty ()->clone ()}; + + case TyTy::TypeKind::INFER: { + TyTy::InferType *to_infer + = static_cast (to.get_ty ()); + + switch (to_infer->get_infer_kind ()) + { + case TyTy::InferType::InferTypeKind::GENERAL: + case TyTy::InferType::InferTypeKind::INTEGRAL: + return TypeCoercionRules::CoercionResult{ + {}, to.get_ty ()->clone ()}; + + default: + return TypeCoercionRules::CoercionResult::get_error (); + } + } + break; + + default: + return TypeCoercionRules::CoercionResult::get_error (); + } + break; + + case TyTy::InferType::InferTypeKind::FLOAT: + switch (to.get_ty ()->get_kind ()) + { + case TyTy::TypeKind::USIZE: + case TyTy::TypeKind::ISIZE: + case TyTy::TypeKind::UINT: + case TyTy::TypeKind::INT: + return TypeCoercionRules::CoercionResult{ + {}, to.get_ty ()->clone ()}; + + case TyTy::TypeKind::INFER: { + TyTy::InferType *to_infer + = static_cast (to.get_ty ()); + + switch (to_infer->get_infer_kind ()) + { + case TyTy::InferType::InferTypeKind::GENERAL: + case TyTy::InferType::InferTypeKind::FLOAT: + return TypeCoercionRules::CoercionResult{ + {}, to.get_ty ()->clone ()}; + + default: + return TypeCoercionRules::CoercionResult::get_error (); + } + } + break; + + default: + return TypeCoercionRules::CoercionResult::get_error (); + } + break; + } + } + break; + + case TyTy::TypeKind::BOOL: + switch (to.get_ty ()->get_kind ()) + { + case TyTy::TypeKind::INFER: + case TyTy::TypeKind::USIZE: + case TyTy::TypeKind::ISIZE: + case TyTy::TypeKind::UINT: + case TyTy::TypeKind::INT: + return TypeCoercionRules::CoercionResult{{}, to.get_ty ()->clone ()}; + + default: + return TypeCoercionRules::CoercionResult::get_error (); + } + break; + + case TyTy::TypeKind::CHAR: + case TyTy::TypeKind::USIZE: + case TyTy::TypeKind::ISIZE: + case TyTy::TypeKind::UINT: + case TyTy::TypeKind::INT: + switch (to.get_ty ()->get_kind ()) + { + case TyTy::TypeKind::CHAR: { + // only u8 and char + bool was_uint = from.get_ty ()->get_kind () == TyTy::TypeKind::UINT; + bool was_u8 = was_uint + && (static_cast (from.get_ty ()) + ->get_uint_kind () + == TyTy::UintType::UintKind::U8); + if (was_u8) + return TypeCoercionRules::CoercionResult{{}, + to.get_ty ()->clone ()}; + } + break; + + case TyTy::TypeKind::INFER: + case TyTy::TypeKind::USIZE: + case TyTy::TypeKind::ISIZE: + case TyTy::TypeKind::UINT: + case TyTy::TypeKind::INT: + return TypeCoercionRules::CoercionResult{{}, to.get_ty ()->clone ()}; + + default: + return TypeCoercionRules::CoercionResult::get_error (); + } + break; + + case TyTy::TypeKind::FLOAT: + switch (to.get_ty ()->get_kind ()) + { + case TyTy::TypeKind::FLOAT: + return TypeCoercionRules::CoercionResult{{}, to.get_ty ()->clone ()}; + + case TyTy::TypeKind::INFER: { + TyTy::InferType *to_infer + = static_cast (to.get_ty ()); + + switch (to_infer->get_infer_kind ()) + { + case TyTy::InferType::InferTypeKind::GENERAL: + case TyTy::InferType::InferTypeKind::FLOAT: + return TypeCoercionRules::CoercionResult{ + {}, to.get_ty ()->clone ()}; + + default: + return TypeCoercionRules::CoercionResult::get_error (); + } + } + break; + + default: + return TypeCoercionRules::CoercionResult::get_error (); + } + break; + + case TyTy::TypeKind::REF: + case TyTy::TypeKind::POINTER: + switch (to.get_ty ()->get_kind ()) + { + case TyTy::TypeKind::REF: + case TyTy::TypeKind::POINTER: + return check_ptr_ptr_cast (); + + // FIXME can you cast a pointer to a integral type? + + default: + return TypeCoercionRules::CoercionResult::get_error (); + } + break; + + default: + return TypeCoercionRules::CoercionResult::get_error (); + } + + return TypeCoercionRules::CoercionResult::get_error (); +} + +TypeCoercionRules::CoercionResult +TypeCastRules::check_ptr_ptr_cast () +{ + rust_debug ("check_ptr_ptr_cast from={%s} to={%s}", + from.get_ty ()->debug_str ().c_str (), + to.get_ty ()->debug_str ().c_str ()); + + bool from_is_ref = from.get_ty ()->get_kind () == TyTy::TypeKind::REF; + bool to_is_ref = to.get_ty ()->get_kind () == TyTy::TypeKind::REF; + bool from_is_ptr = from.get_ty ()->get_kind () == TyTy::TypeKind::POINTER; + bool to_is_ptr = to.get_ty ()->get_kind () == TyTy::TypeKind::POINTER; + + if (from_is_ptr && to_is_ptr) + { + // mutability is ignored here as all pointer usage requires unsafe + return TypeCoercionRules::CoercionResult{{}, to.get_ty ()->clone ()}; + } + else if (from_is_ref && to_is_ref) + { + // mutability must be coercedable + TyTy::ReferenceType &f + = static_cast (*from.get_ty ()); + TyTy::ReferenceType &t + = static_cast (*to.get_ty ()); + + if (TypeCoercionRules::coerceable_mutability (f.mutability (), + t.mutability ())) + { + return TypeCoercionRules::CoercionResult{{}, to.get_ty ()->clone ()}; + } + } + + return TypeCoercionRules::CoercionResult::get_error (); +} + +void +TypeCastRules::emit_cast_error () const +{ + // error[E0604] + RichLocation r (locus); + r.add_range (from.get_locus ()); + r.add_range (to.get_locus ()); + rust_error_at (r, "invalid cast %<%s%> to %<%s%>", + from.get_ty ()->get_name ().c_str (), + to.get_ty ()->get_name ().c_str ()); +} + +} // namespace Resolver +} // namespace Rust diff --git a/gcc/rust/typecheck/rust-casts.h b/gcc/rust/typecheck/rust-casts.h new file mode 100644 index 00000000000..e908f49b656 --- /dev/null +++ b/gcc/rust/typecheck/rust-casts.h @@ -0,0 +1,53 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_CASTS +#define RUST_CASTS + +#include "rust-tyty.h" +#include "rust-coercion.h" + +namespace Rust { +namespace Resolver { + +class TypeCastRules +{ +public: + static TypeCoercionRules::CoercionResult + resolve (Location locus, TyTy::TyWithLocation from, TyTy::TyWithLocation to); + +protected: + TypeCoercionRules::CoercionResult check (); + TypeCoercionRules::CoercionResult cast_rules (); + TypeCoercionRules::CoercionResult check_ptr_ptr_cast (); + + void emit_cast_error () const; + +protected: + TypeCastRules (Location locus, TyTy::TyWithLocation from, + TyTy::TyWithLocation to); + + Location locus; + TyTy::TyWithLocation from; + TyTy::TyWithLocation to; +}; + +} // namespace Resolver +} // namespace Rust + +#endif // RUST_CASTS diff --git a/gcc/rust/typecheck/rust-coercion.cc b/gcc/rust/typecheck/rust-coercion.cc new file mode 100644 index 00000000000..2ad2b8007ff --- /dev/null +++ b/gcc/rust/typecheck/rust-coercion.cc @@ -0,0 +1,357 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#include "rust-coercion.h" + +namespace Rust { +namespace Resolver { + +TypeCoercionRules::CoercionResult +TypeCoercionRules::Coerce (TyTy::BaseType *receiver, TyTy::BaseType *expected, + Location locus) +{ + TypeCoercionRules resolver (expected, locus, true); + bool ok = resolver.do_coercion (receiver); + return ok ? resolver.try_result : CoercionResult::get_error (); +} + +TypeCoercionRules::CoercionResult +TypeCoercionRules::TryCoerce (TyTy::BaseType *receiver, + TyTy::BaseType *expected, Location locus) +{ + TypeCoercionRules resolver (expected, locus, false); + bool ok = resolver.do_coercion (receiver); + return ok ? resolver.try_result : CoercionResult::get_error (); +} + +TypeCoercionRules::TypeCoercionRules (TyTy::BaseType *expected, Location locus, + bool emit_errors) + : AutoderefCycle (false), mappings (Analysis::Mappings::get ()), + context (TypeCheckContext::get ()), expected (expected), locus (locus), + try_result (CoercionResult::get_error ()), emit_errors (emit_errors) +{} + +bool +TypeCoercionRules::do_coercion (TyTy::BaseType *receiver) +{ + // FIXME this is not finished and might be super simplified + // see: + // https://github.com/rust-lang/rust/blob/7eac88abb2e57e752f3302f02be5f3ce3d7adfb4/compiler/rustc_typeck/src/check/coercion.rs + + // unsize + bool unsafe_error = false; + CoercionResult unsize_coercion + = coerce_unsized (receiver, expected, unsafe_error); + bool valid_unsize_coercion = !unsize_coercion.is_error (); + if (valid_unsize_coercion) + { + try_result = unsize_coercion; + return true; + } + else if (unsafe_error) + { + // Location lhs = mappings->lookup_location (receiver->get_ref ()); + // Location rhs = mappings->lookup_location (expected->get_ref ()); + // object_unsafe_error (locus, lhs, rhs); + return false; + } + + // pointers + switch (expected->get_kind ()) + { + case TyTy::TypeKind::POINTER: { + TyTy::PointerType *ptr = static_cast (expected); + try_result = coerce_unsafe_ptr (receiver, ptr, ptr->mutability ()); + return !try_result.is_error (); + } + + case TyTy::TypeKind::REF: { + TyTy::ReferenceType *ptr + = static_cast (expected); + try_result + = coerce_borrowed_pointer (receiver, ptr, ptr->mutability ()); + return !try_result.is_error (); + } + break; + + default: + break; + } + + return !try_result.is_error (); +} + +TypeCoercionRules::CoercionResult +TypeCoercionRules::coerce_unsafe_ptr (TyTy::BaseType *receiver, + TyTy::PointerType *expected, + Mutability to_mutbl) +{ + rust_debug ("coerce_unsafe_ptr(a={%s}, b={%s})", + receiver->debug_str ().c_str (), expected->debug_str ().c_str ()); + + Mutability from_mutbl = Mutability::Imm; + TyTy::BaseType *element = nullptr; + switch (receiver->get_kind ()) + { + case TyTy::TypeKind::REF: { + TyTy::ReferenceType *ref + = static_cast (receiver); + from_mutbl = ref->mutability (); + element = ref->get_base (); + } + break; + + case TyTy::TypeKind::POINTER: { + TyTy::PointerType *ref = static_cast (receiver); + from_mutbl = ref->mutability (); + element = ref->get_base (); + } + break; + + default: { + if (receiver->can_eq (expected, false)) + return CoercionResult{{}, expected->clone ()}; + + return CoercionResult::get_error (); + } + } + + if (!coerceable_mutability (from_mutbl, to_mutbl)) + { + Location lhs = mappings->lookup_location (receiver->get_ref ()); + Location rhs = mappings->lookup_location (expected->get_ref ()); + mismatched_mutability_error (locus, lhs, rhs); + return TypeCoercionRules::CoercionResult::get_error (); + } + + TyTy::PointerType *result + = new TyTy::PointerType (receiver->get_ref (), + TyTy::TyVar (element->get_ref ()), to_mutbl); + if (!result->can_eq (expected, false)) + return CoercionResult::get_error (); + + return CoercionResult{{}, result}; +} + +/// Reborrows `&mut A` to `&mut B` and `&(mut) A` to `&B`. +/// To match `A` with `B`, autoderef will be performed, +/// calling `deref`/`deref_mut` where necessary. +TypeCoercionRules::CoercionResult +TypeCoercionRules::coerce_borrowed_pointer (TyTy::BaseType *receiver, + TyTy::ReferenceType *expected, + Mutability to_mutbl) +{ + rust_debug ("coerce_borrowed_pointer(a={%s}, b={%s})", + receiver->debug_str ().c_str (), expected->debug_str ().c_str ()); + + Mutability from_mutbl = Mutability::Imm; + switch (receiver->get_kind ()) + { + case TyTy::TypeKind::REF: { + TyTy::ReferenceType *ref + = static_cast (receiver); + from_mutbl = ref->mutability (); + } + break; + + default: { + TyTy::BaseType *result = receiver->unify (expected); + return CoercionResult{{}, result}; + } + } + + if (!coerceable_mutability (from_mutbl, to_mutbl)) + { + Location lhs = mappings->lookup_location (receiver->get_ref ()); + Location rhs = mappings->lookup_location (expected->get_ref ()); + mismatched_mutability_error (locus, lhs, rhs); + return TypeCoercionRules::CoercionResult::get_error (); + } + + AutoderefCycle::cycle (receiver); + return try_result; +} + +// &[T; n] or &mut [T; n] -> &[T] +// or &mut [T; n] -> &mut [T] +// or &Concrete -> &Trait, etc. +TypeCoercionRules::CoercionResult +TypeCoercionRules::coerce_unsized (TyTy::BaseType *source, + TyTy::BaseType *target, bool &unsafe_error) +{ + rust_debug ("coerce_unsized(source={%s}, target={%s})", + source->debug_str ().c_str (), target->debug_str ().c_str ()); + + bool source_is_ref = source->get_kind () == TyTy::TypeKind::REF; + bool target_is_ref = target->get_kind () == TyTy::TypeKind::REF; + bool target_is_ptr = target->get_kind () == TyTy::TypeKind::POINTER; + + bool needs_reborrow = false; + TyTy::BaseType *ty_a = source; + TyTy::BaseType *ty_b = target; + Mutability expected_mutability = Mutability::Imm; + if (source_is_ref && target_is_ref) + { + TyTy::ReferenceType *source_ref + = static_cast (source); + TyTy::ReferenceType *target_ref + = static_cast (target); + + Mutability from_mutbl = source_ref->mutability (); + Mutability to_mutbl = target_ref->mutability (); + if (!coerceable_mutability (from_mutbl, to_mutbl)) + { + unsafe_error = true; + Location lhs = mappings->lookup_location (source->get_ref ()); + Location rhs = mappings->lookup_location (target->get_ref ()); + mismatched_mutability_error (locus, lhs, rhs); + return TypeCoercionRules::CoercionResult::get_error (); + } + + ty_a = source_ref->get_base (); + ty_b = target_ref->get_base (); + needs_reborrow = true; + expected_mutability = to_mutbl; + + adjustments.push_back ( + Adjustment (Adjustment::AdjustmentType::INDIRECTION, source_ref, ty_a)); + } + else if (source_is_ref && target_is_ptr) + { + TyTy::ReferenceType *source_ref + = static_cast (source); + TyTy::PointerType *target_ref = static_cast (target); + + Mutability from_mutbl = source_ref->mutability (); + Mutability to_mutbl = target_ref->mutability (); + if (!coerceable_mutability (from_mutbl, to_mutbl)) + { + unsafe_error = true; + Location lhs = mappings->lookup_location (source->get_ref ()); + Location rhs = mappings->lookup_location (target->get_ref ()); + mismatched_mutability_error (locus, lhs, rhs); + return TypeCoercionRules::CoercionResult::get_error (); + } + + ty_a = source_ref->get_base (); + ty_b = target_ref->get_base (); + needs_reborrow = true; + expected_mutability = to_mutbl; + + adjustments.push_back ( + Adjustment (Adjustment::AdjustmentType::INDIRECTION, source_ref, ty_a)); + } + + // FIXME + // there is a bunch of code to ensure something is coerce able to a dyn trait + // we need to support but we need to support a few more lang items for that + // see: + // https://github.com/rust-lang/rust/blob/7eac88abb2e57e752f3302f02be5f3ce3d7adfb4/compiler/rustc_typeck/src/check/coercion.rs#L582 + + const auto a = ty_a; + const auto b = ty_b; + + bool expect_dyn = b->get_kind () == TyTy::TypeKind::DYNAMIC; + bool need_unsize = a->get_kind () != TyTy::TypeKind::DYNAMIC; + + if (expect_dyn && need_unsize) + { + bool bounds_compatible = b->bounds_compatible (*a, locus, true); + if (!bounds_compatible) + { + unsafe_error = true; + return TypeCoercionRules::CoercionResult::get_error (); + } + + // return the unsize coercion + TyTy::BaseType *result = b->clone (); + // result->set_ref (a->get_ref ()); + + // append a dyn coercion adjustment + adjustments.push_back (Adjustment (Adjustment::UNSIZE, a, result)); + + // reborrow if needed + if (needs_reborrow) + { + TyTy::ReferenceType *reborrow + = new TyTy::ReferenceType (source->get_ref (), + TyTy::TyVar (result->get_ref ()), + expected_mutability); + + Adjustment::AdjustmentType borrow_type + = expected_mutability == Mutability::Imm ? Adjustment::IMM_REF + : Adjustment::MUT_REF; + adjustments.push_back (Adjustment (borrow_type, result, reborrow)); + result = reborrow; + } + + return CoercionResult{adjustments, result}; + } + + adjustments.clear (); + return TypeCoercionRules::CoercionResult::get_error (); +} + +bool +TypeCoercionRules::select (const TyTy::BaseType &autoderefed) +{ + if (autoderefed.can_eq (expected, false)) + { + try_result = CoercionResult{adjustments, autoderefed.clone ()}; + return true; + } + return false; +} + +/// Coercing a mutable reference to an immutable works, while +/// coercing `&T` to `&mut T` should be forbidden. +bool +TypeCoercionRules::coerceable_mutability (Mutability from_mutbl, + Mutability to_mutbl) +{ + return to_mutbl == Mutability::Imm || (from_mutbl == to_mutbl); +} + +void +TypeCoercionRules::mismatched_mutability_error (Location expr_locus, + Location lhs, Location rhs) +{ + if (!emit_errors) + return; + + RichLocation r (expr_locus); + r.add_range (lhs); + r.add_range (rhs); + rust_error_at (r, "mismatched mutability"); +} + +void +TypeCoercionRules::object_unsafe_error (Location expr_locus, Location lhs, + Location rhs) +{ + if (!emit_errors) + return; + + RichLocation r (expr_locus); + r.add_range (lhs); + r.add_range (rhs); + rust_error_at (r, "unsafe unsize coercion"); +} + +} // namespace Resolver +} // namespace Rust diff --git a/gcc/rust/typecheck/rust-coercion.h b/gcc/rust/typecheck/rust-coercion.h new file mode 100644 index 00000000000..da28c7c5e1b --- /dev/null +++ b/gcc/rust/typecheck/rust-coercion.h @@ -0,0 +1,93 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_COERCION +#define RUST_COERCION + +#include "rust-autoderef.h" +#include "rust-hir-type-check.h" + +namespace Rust { +namespace Resolver { + +class TypeCoercionRules : protected AutoderefCycle +{ +public: + struct CoercionResult + { + std::vector adjustments; + TyTy::BaseType *tyty; + + bool is_error () + { + return tyty == nullptr || tyty->get_kind () == TyTy::TypeKind::ERROR; + } + + static CoercionResult get_error () { return CoercionResult{{}, nullptr}; } + }; + + static CoercionResult Coerce (TyTy::BaseType *receiver, + TyTy::BaseType *expected, Location locus); + + static CoercionResult TryCoerce (TyTy::BaseType *receiver, + TyTy::BaseType *expected, Location locus); + + CoercionResult coerce_unsafe_ptr (TyTy::BaseType *receiver, + TyTy::PointerType *expected, + Mutability mutability); + + CoercionResult coerce_borrowed_pointer (TyTy::BaseType *receiver, + TyTy::ReferenceType *expected, + Mutability mutability); + + CoercionResult coerce_unsized (TyTy::BaseType *receiver, + TyTy::BaseType *expected, bool &unsafe_error); + + static bool coerceable_mutability (Mutability from_mutbl, + Mutability to_mutbl); + + void mismatched_mutability_error (Location expr_locus, Location lhs, + Location rhs); + void object_unsafe_error (Location expr_locus, Location lhs, Location rhs); + +protected: + TypeCoercionRules (TyTy::BaseType *expected, Location locus, + bool emit_errors); + + bool select (const TyTy::BaseType &autoderefed) override; + + bool do_coercion (TyTy::BaseType *receiver); + +private: + // context info + Analysis::Mappings *mappings; + TypeCheckContext *context; + + // search + TyTy::BaseType *expected; + Location locus; + + // mutable fields + CoercionResult try_result; + bool emit_errors; +}; + +} // namespace Resolver +} // namespace Rust + +#endif // RUST_COERCION diff --git a/gcc/rust/typecheck/rust-hir-dot-operator.cc b/gcc/rust/typecheck/rust-hir-dot-operator.cc new file mode 100644 index 00000000000..d45f0903478 --- /dev/null +++ b/gcc/rust/typecheck/rust-hir-dot-operator.cc @@ -0,0 +1,263 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#include "rust-hir-dot-operator.h" +#include "rust-hir-path-probe.h" +#include "rust-hir-trait-resolve.h" + +namespace Rust { +namespace Resolver { + +MethodResolver::MethodResolver (bool autoderef_flag, + const HIR::PathIdentSegment &segment_name) + : AutoderefCycle (autoderef_flag), mappings (Analysis::Mappings::get ()), + context (TypeCheckContext::get ()), segment_name (segment_name), + try_result (MethodCandidate::get_error ()) +{} + +MethodCandidate +MethodResolver::Probe (const TyTy::BaseType *receiver, + const HIR::PathIdentSegment &segment_name, + bool autoderef_flag) +{ + MethodResolver resolver (autoderef_flag, segment_name); + bool ok = resolver.cycle (receiver); + return ok ? resolver.try_result : MethodCandidate::get_error (); +} + +void +MethodResolver::try_hook (const TyTy::BaseType &r) +{ + const auto &specified_bounds = r.get_specified_bounds (); + predicate_items = get_predicate_items (segment_name, r, specified_bounds); +} + +bool +MethodResolver::select (const TyTy::BaseType &receiver) +{ + struct impl_item_candidate + { + HIR::Function *item; + HIR::ImplBlock *impl_block; + TyTy::FnType *ty; + }; + + // assemble inherent impl items + std::vector inherent_impl_fns; + mappings->iterate_impl_items ( + [&] (HirId id, HIR::ImplItem *item, HIR::ImplBlock *impl) mutable -> bool { + bool is_trait_impl = impl->has_trait_ref (); + if (is_trait_impl) + return true; + + bool is_fn + = item->get_impl_item_type () == HIR::ImplItem::ImplItemType::FUNCTION; + if (!is_fn) + return true; + + HIR::Function *func = static_cast (item); + if (!func->is_method ()) + return true; + + bool name_matches + = func->get_function_name ().compare (segment_name.as_string ()) == 0; + if (!name_matches) + return true; + + TyTy::BaseType *ty = nullptr; + if (!context->lookup_type (func->get_mappings ().get_hirid (), &ty)) + return true; + if (ty->get_kind () == TyTy::TypeKind::ERROR) + return true; + + rust_assert (ty->get_kind () == TyTy::TypeKind::FNDEF); + TyTy::FnType *fnty = static_cast (ty); + + inherent_impl_fns.push_back ({func, impl, fnty}); + + return true; + }); + + struct trait_item_candidate + { + const HIR::TraitItemFunc *item; + const HIR::Trait *trait; + TyTy::FnType *ty; + const TraitReference *reference; + const TraitItemReference *item_ref; + }; + + std::vector trait_fns; + mappings->iterate_impl_blocks ([&] (HirId id, + HIR::ImplBlock *impl) mutable -> bool { + bool is_trait_impl = impl->has_trait_ref (); + if (!is_trait_impl) + return true; + + // look for impl implementation else lookup the associated trait item + for (auto &impl_item : impl->get_impl_items ()) + { + bool is_fn = impl_item->get_impl_item_type () + == HIR::ImplItem::ImplItemType::FUNCTION; + if (!is_fn) + continue; + + HIR::Function *func = static_cast (impl_item.get ()); + if (!func->is_method ()) + continue; + + bool name_matches + = func->get_function_name ().compare (segment_name.as_string ()) == 0; + if (!name_matches) + continue; + + TyTy::BaseType *ty = nullptr; + if (!context->lookup_type (func->get_mappings ().get_hirid (), &ty)) + continue; + if (ty->get_kind () == TyTy::TypeKind::ERROR) + continue; + + rust_assert (ty->get_kind () == TyTy::TypeKind::FNDEF); + TyTy::FnType *fnty = static_cast (ty); + + inherent_impl_fns.push_back ({func, impl, fnty}); + return true; + } + + TraitReference *trait_ref + = TraitResolver::Resolve (*impl->get_trait_ref ().get ()); + rust_assert (!trait_ref->is_error ()); + + auto item_ref + = trait_ref->lookup_trait_item (segment_name.as_string (), + TraitItemReference::TraitItemType::FN); + if (item_ref->is_error ()) + return true; + + const HIR::Trait *trait = trait_ref->get_hir_trait_ref (); + HIR::TraitItem *item = item_ref->get_hir_trait_item (); + rust_assert (item->get_item_kind () == HIR::TraitItem::TraitItemKind::FUNC); + HIR::TraitItemFunc *func = static_cast (item); + + TyTy::BaseType *ty = item_ref->get_tyty (); + rust_assert (ty->get_kind () == TyTy::TypeKind::FNDEF); + TyTy::FnType *fnty = static_cast (ty); + + trait_item_candidate candidate{func, trait, fnty, trait_ref, item_ref}; + trait_fns.push_back (candidate); + + return true; + }); + + // lookup specified bounds for an associated item + struct precdicate_candidate + { + TyTy::TypeBoundPredicateItem lookup; + TyTy::FnType *fntype; + }; + + for (auto impl_item : inherent_impl_fns) + { + TyTy::FnType *fn = impl_item.ty; + rust_assert (fn->is_method ()); + + TyTy::BaseType *fn_self = fn->get_self_type (); + if (fn_self->can_eq (&receiver, false)) + { + PathProbeCandidate::ImplItemCandidate c{impl_item.item, + impl_item.impl_block}; + try_result = MethodCandidate{ + PathProbeCandidate (PathProbeCandidate::CandidateType::IMPL_FUNC, + fn, impl_item.item->get_locus (), c), + adjustments}; + return true; + } + } + + for (auto trait_item : trait_fns) + { + TyTy::FnType *fn = trait_item.ty; + rust_assert (fn->is_method ()); + + TyTy::BaseType *fn_self = fn->get_self_type (); + if (fn_self->can_eq (&receiver, false)) + { + PathProbeCandidate::TraitItemCandidate c{trait_item.reference, + trait_item.item_ref, + nullptr}; + try_result = MethodCandidate{ + PathProbeCandidate (PathProbeCandidate::CandidateType::TRAIT_FUNC, + fn, trait_item.item->get_locus (), c), + adjustments}; + return true; + } + } + + for (const auto &predicate : predicate_items) + { + const TyTy::FnType *fn = predicate.fntype; + rust_assert (fn->is_method ()); + + TyTy::BaseType *fn_self = fn->get_self_type (); + if (fn_self->can_eq (&receiver, false)) + { + const TraitReference *trait_ref + = predicate.lookup.get_parent ()->get (); + const TraitItemReference *trait_item + = predicate.lookup.get_raw_item (); + + PathProbeCandidate::TraitItemCandidate c{trait_ref, trait_item, + nullptr}; + try_result = MethodCandidate{ + PathProbeCandidate (PathProbeCandidate::CandidateType::TRAIT_FUNC, + fn->clone (), trait_item->get_locus (), c), + adjustments}; + return true; + } + } + + return false; +} + +std::vector +MethodResolver::get_predicate_items ( + const HIR::PathIdentSegment &segment_name, const TyTy::BaseType &receiver, + const std::vector &specified_bounds) +{ + std::vector predicate_items; + for (auto &bound : specified_bounds) + { + TyTy::TypeBoundPredicateItem lookup + = bound.lookup_associated_item (segment_name.as_string ()); + if (lookup.is_error ()) + continue; + + TyTy::BaseType *ty = lookup.get_tyty_for_receiver (&receiver); + if (ty->get_kind () == TyTy::TypeKind::FNDEF) + { + TyTy::FnType *fnty = static_cast (ty); + predicate_candidate candidate{lookup, fnty}; + predicate_items.push_back (candidate); + } + } + + return predicate_items; +} + +} // namespace Resolver +} // namespace Rust diff --git a/gcc/rust/typecheck/rust-hir-dot-operator.h b/gcc/rust/typecheck/rust-hir-dot-operator.h new file mode 100644 index 00000000000..750601a2d9e --- /dev/null +++ b/gcc/rust/typecheck/rust-hir-dot-operator.h @@ -0,0 +1,81 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_HIR_DOT_OPERATOR +#define RUST_HIR_DOT_OPERATOR + +#include "rust-hir-path-probe.h" + +namespace Rust { +namespace Resolver { + +struct MethodCandidate +{ + PathProbeCandidate candidate; + std::vector adjustments; + + static MethodCandidate get_error () + { + return {PathProbeCandidate::get_error (), {}}; + } + + bool is_error () const { return candidate.is_error (); } +}; + +class MethodResolver : protected AutoderefCycle +{ +public: + struct predicate_candidate + { + TyTy::TypeBoundPredicateItem lookup; + TyTy::FnType *fntype; + }; + + static MethodCandidate Probe (const TyTy::BaseType *receiver, + const HIR::PathIdentSegment &segment_name, + bool autoderef_flag = false); + + static std::vector get_predicate_items ( + const HIR::PathIdentSegment &segment_name, const TyTy::BaseType &receiver, + const std::vector &specified_bounds); + +protected: + MethodResolver (bool autoderef_flag, + const HIR::PathIdentSegment &segment_name); + + void try_hook (const TyTy::BaseType &r) override; + + bool select (const TyTy::BaseType &receiver) override; + +private: + // context info + Analysis::Mappings *mappings; + TypeCheckContext *context; + + // search + const HIR::PathIdentSegment &segment_name; + std::vector predicate_items; + + // mutable fields + MethodCandidate try_result; +}; + +} // namespace Resolver +} // namespace Rust + +#endif // RUST_HIR_DOT_OPERATOR diff --git a/gcc/rust/typecheck/rust-hir-inherent-impl-overlap.h b/gcc/rust/typecheck/rust-hir-inherent-impl-overlap.h new file mode 100644 index 00000000000..2890b54a00d --- /dev/null +++ b/gcc/rust/typecheck/rust-hir-inherent-impl-overlap.h @@ -0,0 +1,186 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_HIR_INHERENT_IMPL_ITEM_OVERLAP_H +#define RUST_HIR_INHERENT_IMPL_ITEM_OVERLAP_H + +#include "rust-hir-type-check-base.h" +#include "rust-hir-full.h" + +namespace Rust { +namespace Resolver { + +class ImplItemToName : private TypeCheckBase, private HIR::HIRImplVisitor +{ +public: + static bool resolve (HIR::ImplItem *item, std::string &name_result) + { + ImplItemToName resolver (name_result); + item->accept_vis (resolver); + return resolver.ok; + } + + void visit (HIR::TypeAlias &alias) override + { + ok = true; + result.assign (alias.get_new_type_name ()); + } + + void visit (HIR::Function &function) override + { + ok = true; + result.assign (function.get_function_name ()); + } + + void visit (HIR::ConstantItem &constant) override + { + ok = true; + result.assign (constant.get_identifier ()); + } + +private: + ImplItemToName (std::string &result) + : TypeCheckBase (), ok (false), result (result) + {} + + bool ok; + std::string &result; +}; + +class OverlappingImplItemPass : public TypeCheckBase +{ +public: + static void go () + { + OverlappingImplItemPass pass; + + // generate mappings + pass.mappings->iterate_impl_items ( + [&] (HirId id, HIR::ImplItem *impl_item, HIR::ImplBlock *impl) -> bool { + // ignoring trait-impls might need thought later on + if (impl->has_trait_ref ()) + return true; + + pass.process_impl_item (id, impl_item, impl); + return true; + }); + + pass.scan (); + } + + void process_impl_item (HirId id, HIR::ImplItem *impl_item, + HIR::ImplBlock *impl) + { + // lets make a mapping of impl-item Self type to (impl-item,name): + // { + // impl-type -> [ (item, name), ... ] + // } + + HirId impl_type_id = impl->get_type ()->get_mappings ().get_hirid (); + TyTy::BaseType *impl_type = nullptr; + bool ok = context->lookup_type (impl_type_id, &impl_type); + rust_assert (ok); + + std::string impl_item_name; + ok = ImplItemToName::resolve (impl_item, impl_item_name); + rust_assert (ok); + + std::pair elem (impl_item, impl_item_name); + impl_mappings[impl_type].insert (std::move (elem)); + } + + void scan () + { + // we can now brute force the map looking for can_eq on each of the + // impl_items_types to look for possible colliding impl blocks; + for (auto it = impl_mappings.begin (); it != impl_mappings.end (); it++) + { + TyTy::BaseType *query = it->first; + + for (auto iy = impl_mappings.begin (); iy != impl_mappings.end (); iy++) + { + TyTy::BaseType *candidate = iy->first; + if (query == candidate) + continue; + + if (query->can_eq (candidate, false)) + { + // we might be in the case that we have: + // + // *const T vs *const [T] + // + // so lets use an equality check when the + // candidates are both generic to be sure we dont emit a false + // positive + + bool a = query->is_concrete (); + bool b = candidate->is_concrete (); + bool both_generic = !a && !b; + if (both_generic) + { + if (!query->is_equal (*candidate)) + continue; + } + + possible_collision (it->second, iy->second); + } + } + } + } + + void possible_collision ( + std::set > query, + std::set > candidate) + { + for (auto &q : query) + { + HIR::ImplItem *query_impl_item = q.first; + std::string query_impl_item_name = q.second; + + for (auto &c : candidate) + { + HIR::ImplItem *candidate_impl_item = c.first; + std::string candidate_impl_item_name = c.second; + + if (query_impl_item_name.compare (candidate_impl_item_name) == 0) + collision_detected (query_impl_item, candidate_impl_item, + candidate_impl_item_name); + } + } + } + + void collision_detected (HIR::ImplItem *query, HIR::ImplItem *dup, + const std::string &name) + { + RichLocation r (dup->get_locus ()); + r.add_range (query->get_locus ()); + rust_error_at (r, "duplicate definitions with name %s", name.c_str ()); + } + +private: + OverlappingImplItemPass () : TypeCheckBase () {} + + std::map > > + impl_mappings; +}; + +} // namespace Resolver +} // namespace Rust + +#endif // RUST_HIR_INHERENT_IMPL_ITEM_OVERLAP_H diff --git a/gcc/rust/typecheck/rust-hir-path-probe.h b/gcc/rust/typecheck/rust-hir-path-probe.h new file mode 100644 index 00000000000..bd4f91e49bf --- /dev/null +++ b/gcc/rust/typecheck/rust-hir-path-probe.h @@ -0,0 +1,540 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_HIR_PATH_PROBE_H +#define RUST_HIR_PATH_PROBE_H + +#include "rust-hir-type-check-base.h" +#include "rust-hir-full.h" +#include "rust-tyty.h" +#include "rust-substitution-mapper.h" +#include "rust-hir-type-bounds.h" + +namespace Rust { +namespace Resolver { + +struct PathProbeCandidate +{ + enum CandidateType + { + ERROR, + + ENUM_VARIANT, + + IMPL_CONST, + IMPL_TYPE_ALIAS, + IMPL_FUNC, + + TRAIT_ITEM_CONST, + TRAIT_TYPE_ALIAS, + TRAIT_FUNC, + }; + + struct EnumItemCandidate + { + const TyTy::ADTType *parent; + const TyTy::VariantDef *variant; + }; + + struct ImplItemCandidate + { + HIR::ImplItem *impl_item; + HIR::ImplBlock *parent; + }; + + struct TraitItemCandidate + { + const TraitReference *trait_ref; + const TraitItemReference *item_ref; + HIR::ImplBlock *impl; + }; + + CandidateType type; + TyTy::BaseType *ty; + Location locus; + union Candidate + { + EnumItemCandidate enum_field; + ImplItemCandidate impl; + TraitItemCandidate trait; + + Candidate (EnumItemCandidate enum_field) : enum_field (enum_field) {} + Candidate (ImplItemCandidate impl) : impl (impl) {} + Candidate (TraitItemCandidate trait) : trait (trait) {} + } item; + + PathProbeCandidate (CandidateType type, TyTy::BaseType *ty, Location locus, + EnumItemCandidate enum_field) + : type (type), ty (ty), item (enum_field) + {} + + PathProbeCandidate (CandidateType type, TyTy::BaseType *ty, Location locus, + ImplItemCandidate impl) + : type (type), ty (ty), item (impl) + {} + + PathProbeCandidate (CandidateType type, TyTy::BaseType *ty, Location locus, + TraitItemCandidate trait) + : type (type), ty (ty), item (trait) + {} + + std::string as_string () const + { + return "PathProbe candidate TODO - as_string"; + } + + bool is_enum_candidate () const { return type == ENUM_VARIANT; } + + bool is_impl_candidate () const + { + return type == IMPL_CONST || type == IMPL_TYPE_ALIAS || type == IMPL_FUNC; + } + + bool is_trait_candidate () const + { + return type == TRAIT_ITEM_CONST || type == TRAIT_TYPE_ALIAS + || type == TRAIT_FUNC; + } + + bool is_full_trait_item_candidate () const + { + return is_trait_candidate () && item.trait.impl == nullptr; + } + + static PathProbeCandidate get_error () + { + return PathProbeCandidate (ERROR, nullptr, Location (), + ImplItemCandidate{nullptr, nullptr}); + } + + bool is_error () const { return type == ERROR; } +}; + +class PathProbeType : public TypeCheckBase, public HIR::HIRImplVisitor +{ +public: + static std::vector + Probe (const TyTy::BaseType *receiver, + const HIR::PathIdentSegment &segment_name, bool probe_impls, + bool probe_bounds, bool ignore_mandatory_trait_items, + DefId specific_trait_id = UNKNOWN_DEFID) + { + PathProbeType probe (receiver, segment_name, specific_trait_id); + if (probe_impls) + { + if (receiver->get_kind () == TyTy::TypeKind::ADT) + { + const TyTy::ADTType *adt + = static_cast (receiver); + if (adt->is_enum ()) + probe.process_enum_item_for_candiates (adt); + } + + probe.process_impl_items_for_candidates (); + } + + if (!probe_bounds) + return probe.candidates; + + if (!probe.is_reciever_generic ()) + { + std::vector> probed_bounds + = TypeBoundsProbe::Probe (receiver); + for (auto &candidate : probed_bounds) + { + const TraitReference *trait_ref = candidate.first; + if (specific_trait_id != UNKNOWN_DEFID) + { + if (trait_ref->get_mappings ().get_defid () + != specific_trait_id) + continue; + } + + HIR::ImplBlock *impl = candidate.second; + probe.process_associated_trait_for_candidates ( + trait_ref, impl, ignore_mandatory_trait_items); + } + } + + for (const TyTy::TypeBoundPredicate &predicate : + receiver->get_specified_bounds ()) + { + const TraitReference *trait_ref = predicate.get (); + if (specific_trait_id != UNKNOWN_DEFID) + { + if (trait_ref->get_mappings ().get_defid () != specific_trait_id) + continue; + } + + probe.process_predicate_for_candidates (predicate, + ignore_mandatory_trait_items); + } + + return probe.candidates; + } + + void visit (HIR::TypeAlias &alias) override + { + Identifier name = alias.get_new_type_name (); + if (search.as_string ().compare (name) == 0) + { + HirId tyid = alias.get_mappings ().get_hirid (); + TyTy::BaseType *ty = nullptr; + bool ok = context->lookup_type (tyid, &ty); + rust_assert (ok); + + PathProbeCandidate::ImplItemCandidate impl_item_candidate{&alias, + current_impl}; + PathProbeCandidate candidate{ + PathProbeCandidate::CandidateType::IMPL_TYPE_ALIAS, ty, + alias.get_locus (), impl_item_candidate}; + candidates.push_back (std::move (candidate)); + } + } + + void visit (HIR::ConstantItem &constant) override + { + Identifier name = constant.get_identifier (); + if (search.as_string ().compare (name) == 0) + { + HirId tyid = constant.get_mappings ().get_hirid (); + TyTy::BaseType *ty = nullptr; + bool ok = context->lookup_type (tyid, &ty); + rust_assert (ok); + + PathProbeCandidate::ImplItemCandidate impl_item_candidate{&constant, + current_impl}; + PathProbeCandidate candidate{ + PathProbeCandidate::CandidateType::IMPL_CONST, ty, + constant.get_locus (), impl_item_candidate}; + candidates.push_back (std::move (candidate)); + } + } + + void visit (HIR::Function &function) override + { + Identifier name = function.get_function_name (); + if (search.as_string ().compare (name) == 0) + { + HirId tyid = function.get_mappings ().get_hirid (); + TyTy::BaseType *ty = nullptr; + bool ok = context->lookup_type (tyid, &ty); + rust_assert (ok); + + PathProbeCandidate::ImplItemCandidate impl_item_candidate{&function, + current_impl}; + PathProbeCandidate candidate{ + PathProbeCandidate::CandidateType::IMPL_FUNC, ty, + function.get_locus (), impl_item_candidate}; + candidates.push_back (std::move (candidate)); + } + } + +protected: + void process_enum_item_for_candiates (const TyTy::ADTType *adt) + { + if (specific_trait_id != UNKNOWN_DEFID) + return; + + TyTy::VariantDef *v; + if (!adt->lookup_variant (search.as_string (), &v)) + return; + + PathProbeCandidate::EnumItemCandidate enum_item_candidate{adt, v}; + PathProbeCandidate candidate{ + PathProbeCandidate::CandidateType::ENUM_VARIANT, receiver->clone (), + mappings->lookup_location (adt->get_ty_ref ()), enum_item_candidate}; + candidates.push_back (std::move (candidate)); + } + + void process_impl_items_for_candidates () + { + mappings->iterate_impl_items ([&] (HirId id, HIR::ImplItem *item, + HIR::ImplBlock *impl) mutable -> bool { + process_impl_item_candidate (id, item, impl); + return true; + }); + } + + void process_impl_item_candidate (HirId id, HIR::ImplItem *item, + HIR::ImplBlock *impl) + { + current_impl = impl; + HirId impl_ty_id = impl->get_type ()->get_mappings ().get_hirid (); + TyTy::BaseType *impl_block_ty = nullptr; + if (!context->lookup_type (impl_ty_id, &impl_block_ty)) + return; + + if (!receiver->can_eq (impl_block_ty, false)) + { + if (!impl_block_ty->can_eq (receiver, false)) + return; + } + + // lets visit the impl_item + item->accept_vis (*this); + } + + void + process_associated_trait_for_candidates (const TraitReference *trait_ref, + HIR::ImplBlock *impl, + bool ignore_mandatory_trait_items) + { + const TraitItemReference *trait_item_ref = nullptr; + if (!trait_ref->lookup_trait_item (search.as_string (), &trait_item_ref)) + return; + + bool trait_item_needs_implementation = !trait_item_ref->is_optional (); + if (ignore_mandatory_trait_items && trait_item_needs_implementation) + return; + + PathProbeCandidate::CandidateType candidate_type; + switch (trait_item_ref->get_trait_item_type ()) + { + case TraitItemReference::TraitItemType::FN: + candidate_type = PathProbeCandidate::CandidateType::TRAIT_FUNC; + break; + case TraitItemReference::TraitItemType::CONST: + candidate_type = PathProbeCandidate::CandidateType::TRAIT_ITEM_CONST; + break; + case TraitItemReference::TraitItemType::TYPE: + candidate_type = PathProbeCandidate::CandidateType::TRAIT_TYPE_ALIAS; + break; + + case TraitItemReference::TraitItemType::ERROR: + default: + gcc_unreachable (); + break; + } + + TyTy::BaseType *trait_item_tyty = trait_item_ref->get_tyty (); + + // we can substitute the Self with the receiver here + if (trait_item_tyty->get_kind () == TyTy::TypeKind::FNDEF) + { + TyTy::FnType *fn = static_cast (trait_item_tyty); + TyTy::SubstitutionParamMapping *param = nullptr; + for (auto ¶m_mapping : fn->get_substs ()) + { + const HIR::TypeParam &type_param + = param_mapping.get_generic_param (); + if (type_param.get_type_representation ().compare ("Self") == 0) + { + param = ¶m_mapping; + break; + } + } + rust_assert (param != nullptr); + + std::vector mappings; + mappings.push_back (TyTy::SubstitutionArg (param, receiver->clone ())); + + Location locus; // FIXME + TyTy::SubstitutionArgumentMappings args (std::move (mappings), locus); + trait_item_tyty = SubstMapperInternal::Resolve (trait_item_tyty, args); + } + + PathProbeCandidate::TraitItemCandidate trait_item_candidate{trait_ref, + trait_item_ref, + impl}; + + PathProbeCandidate candidate{candidate_type, trait_item_tyty, + trait_ref->get_locus (), trait_item_candidate}; + candidates.push_back (std::move (candidate)); + } + + void + process_predicate_for_candidates (const TyTy::TypeBoundPredicate &predicate, + bool ignore_mandatory_trait_items) + { + const TraitReference *trait_ref = predicate.get (); + + TyTy::TypeBoundPredicateItem item + = predicate.lookup_associated_item (search.as_string ()); + if (item.is_error ()) + return; + + if (ignore_mandatory_trait_items && item.needs_implementation ()) + return; + + const TraitItemReference *trait_item_ref = item.get_raw_item (); + PathProbeCandidate::CandidateType candidate_type; + switch (trait_item_ref->get_trait_item_type ()) + { + case TraitItemReference::TraitItemType::FN: + candidate_type = PathProbeCandidate::CandidateType::TRAIT_FUNC; + break; + case TraitItemReference::TraitItemType::CONST: + candidate_type = PathProbeCandidate::CandidateType::TRAIT_ITEM_CONST; + break; + case TraitItemReference::TraitItemType::TYPE: + candidate_type = PathProbeCandidate::CandidateType::TRAIT_TYPE_ALIAS; + break; + + case TraitItemReference::TraitItemType::ERROR: + default: + gcc_unreachable (); + break; + } + + TyTy::BaseType *trait_item_tyty = item.get_tyty_for_receiver (receiver); + PathProbeCandidate::TraitItemCandidate trait_item_candidate{trait_ref, + trait_item_ref, + nullptr}; + PathProbeCandidate candidate{candidate_type, trait_item_tyty, + trait_item_ref->get_locus (), + trait_item_candidate}; + candidates.push_back (std::move (candidate)); + } + +protected: + PathProbeType (const TyTy::BaseType *receiver, + const HIR::PathIdentSegment &query, DefId specific_trait_id) + : TypeCheckBase (), receiver (receiver), search (query), + current_impl (nullptr), specific_trait_id (specific_trait_id) + {} + + std::vector> + union_bounds ( + const std::vector> + a, + const std::vector> b) + const + { + std::map> mapper; + for (auto &ref : a) + { + mapper.insert ({ref.first->get_mappings ().get_defid (), ref}); + } + for (auto &ref : b) + { + mapper.insert ({ref.first->get_mappings ().get_defid (), ref}); + } + + std::vector> union_set; + for (auto it = mapper.begin (); it != mapper.end (); it++) + { + union_set.push_back ({it->second.first, it->second.second}); + } + return union_set; + } + + bool is_reciever_generic () const + { + const TyTy::BaseType *root = receiver->get_root (); + bool receiver_is_type_param = root->get_kind () == TyTy::TypeKind::PARAM; + bool receiver_is_dyn = root->get_kind () == TyTy::TypeKind::DYNAMIC; + return receiver_is_type_param || receiver_is_dyn; + } + + const TyTy::BaseType *receiver; + const HIR::PathIdentSegment &search; + std::vector candidates; + HIR::ImplBlock *current_impl; + DefId specific_trait_id; +}; + +class ReportMultipleCandidateError : private TypeCheckBase, + private HIR::HIRImplVisitor +{ +public: + static void Report (std::vector &candidates, + const HIR::PathIdentSegment &query, Location query_locus) + { + RichLocation r (query_locus); + ReportMultipleCandidateError visitor (r); + for (auto &c : candidates) + { + switch (c.type) + { + case PathProbeCandidate::CandidateType::ERROR: + case PathProbeCandidate::CandidateType::ENUM_VARIANT: + gcc_unreachable (); + break; + + case PathProbeCandidate::CandidateType::IMPL_CONST: + case PathProbeCandidate::CandidateType::IMPL_TYPE_ALIAS: + case PathProbeCandidate::CandidateType::IMPL_FUNC: + c.item.impl.impl_item->accept_vis (visitor); + break; + + case PathProbeCandidate::CandidateType::TRAIT_ITEM_CONST: + case PathProbeCandidate::CandidateType::TRAIT_TYPE_ALIAS: + case PathProbeCandidate::CandidateType::TRAIT_FUNC: + r.add_range (c.item.trait.item_ref->get_locus ()); + break; + } + } + + rust_error_at (r, "multiple applicable items in scope for: %s", + query.as_string ().c_str ()); + } + + void visit (HIR::TypeAlias &alias) override + { + r.add_range (alias.get_locus ()); + } + + void visit (HIR::ConstantItem &constant) override + { + r.add_range (constant.get_locus ()); + } + + void visit (HIR::Function &function) override + { + r.add_range (function.get_locus ()); + } + +private: + ReportMultipleCandidateError (RichLocation &r) : TypeCheckBase (), r (r) {} + + RichLocation &r; +}; + +class PathProbeImplTrait : public PathProbeType +{ +public: + static std::vector + Probe (const TyTy::BaseType *receiver, + const HIR::PathIdentSegment &segment_name, + const TraitReference *trait_reference) + { + PathProbeImplTrait probe (receiver, segment_name, trait_reference); + // iterate all impls for this trait and receiver + // then search for possible candidates using base class behaviours + probe.process_trait_impl_items_for_candidates (); + return probe.candidates; + } + +private: + void process_trait_impl_items_for_candidates (); + + PathProbeImplTrait (const TyTy::BaseType *receiver, + const HIR::PathIdentSegment &query, + const TraitReference *trait_reference) + : PathProbeType (receiver, query, UNKNOWN_DEFID), + trait_reference (trait_reference) + {} + + const TraitReference *trait_reference; +}; + +} // namespace Resolver +} // namespace Rust + +#endif // RUST_HIR_PATH_PROBE_H diff --git a/gcc/rust/typecheck/rust-hir-trait-ref.h b/gcc/rust/typecheck/rust-hir-trait-ref.h new file mode 100644 index 00000000000..6eec461e8a5 --- /dev/null +++ b/gcc/rust/typecheck/rust-hir-trait-ref.h @@ -0,0 +1,472 @@ +// Copyright (C) 2021-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_HIR_TRAIT_REF_H +#define RUST_HIR_TRAIT_REF_H + +#include "rust-hir-full.h" +#include "rust-tyty-visitor.h" +#include "rust-hir-type-check-util.h" + +namespace Rust { +namespace Resolver { + +// Data Objects for the associated trait items in a structure we can work with +// https://doc.rust-lang.org/edition-guide/rust-2018/trait-system/associated-constants.html +class TypeCheckContext; +class TraitItemReference +{ +public: + enum TraitItemType + { + FN, + CONST, + TYPE, + ERROR + }; + + TraitItemReference (std::string identifier, bool optional, TraitItemType type, + HIR::TraitItem *hir_trait_item, TyTy::BaseType *self, + std::vector substitutions, + Location locus); + + TraitItemReference (TraitItemReference const &other); + + TraitItemReference &operator= (TraitItemReference const &other); + + static TraitItemReference error () + { + return TraitItemReference ("", false, ERROR, nullptr, nullptr, {}, + Location ()); + } + + static TraitItemReference &error_node () + { + static TraitItemReference error = TraitItemReference::error (); + return error; + } + + bool is_error () const { return type == ERROR; } + + std::string as_string () const + { + return "(" + trait_item_type_as_string (type) + " " + identifier + " " + + ")"; + } + + static std::string trait_item_type_as_string (TraitItemType ty) + { + switch (ty) + { + case FN: + return "FN"; + case CONST: + return "CONST"; + case TYPE: + return "TYPE"; + case ERROR: + return "ERROR"; + } + return "ERROR"; + } + + bool is_optional () const { return optional_flag; } + + std::string get_identifier () const { return identifier; } + + TraitItemType get_trait_item_type () const { return type; } + + HIR::TraitItem *get_hir_trait_item () const { return hir_trait_item; } + + Location get_locus () const { return locus; } + + const Analysis::NodeMapping get_mappings () const + { + return hir_trait_item->get_mappings (); + } + + TyTy::BaseType *get_tyty () const + { + rust_assert (hir_trait_item != nullptr); + + switch (type) + { + case CONST: + return get_type_from_constant ( + static_cast (*hir_trait_item)); + break; + + case TYPE: + return get_type_from_typealias ( + static_cast (*hir_trait_item)); + + case FN: + return get_type_from_fn ( + static_cast (*hir_trait_item)); + break; + + default: + return get_error (); + } + + gcc_unreachable (); + return get_error (); + } + + Analysis::NodeMapping get_parent_trait_mappings () const; + + // this is called when the trait is completed resolution and gives the items a + // chance to run their specific type resolution passes. If we call their + // resolution on construction it can lead to a case where the trait being + // resolved recursively trying to resolve the trait itself infinitely since + // the trait will not be stored in its own map yet + void on_resolved (); + + void associated_type_set (TyTy::BaseType *ty) const; + + void associated_type_reset () const; + + bool is_object_safe () const; + +private: + TyTy::ErrorType *get_error () const + { + return new TyTy::ErrorType (get_mappings ().get_hirid ()); + } + + TyTy::BaseType *get_type_from_typealias (/*const*/ + HIR::TraitItemType &type) const; + + TyTy::BaseType * + get_type_from_constant (/*const*/ HIR::TraitItemConst &constant) const; + + TyTy::BaseType *get_type_from_fn (/*const*/ HIR::TraitItemFunc &fn) const; + + bool is_item_resolved () const; + void resolve_item (HIR::TraitItemType &type); + void resolve_item (HIR::TraitItemConst &constant); + void resolve_item (HIR::TraitItemFunc &func); + + std::string identifier; + bool optional_flag; + TraitItemType type; + HIR::TraitItem *hir_trait_item; + std::vector inherited_substitutions; + Location locus; + + TyTy::BaseType + *self; // this is the implict Self TypeParam required for methods + Resolver::TypeCheckContext *context; +}; + +// this wraps up the HIR::Trait so we can do analysis on it + +class TraitReference +{ +public: + TraitReference (const HIR::Trait *hir_trait_ref, + std::vector item_refs, + std::vector super_traits, + std::vector substs) + : hir_trait_ref (hir_trait_ref), item_refs (item_refs), + super_traits (super_traits) + { + trait_substs.clear (); + trait_substs.reserve (substs.size ()); + for (const auto &p : substs) + trait_substs.push_back (p.clone ()); + } + + TraitReference (TraitReference const &other) + : hir_trait_ref (other.hir_trait_ref), item_refs (other.item_refs), + super_traits (other.super_traits) + { + trait_substs.clear (); + trait_substs.reserve (other.trait_substs.size ()); + for (const auto &p : other.trait_substs) + trait_substs.push_back (p.clone ()); + } + + TraitReference &operator= (TraitReference const &other) + { + hir_trait_ref = other.hir_trait_ref; + item_refs = other.item_refs; + super_traits = other.super_traits; + + trait_substs.clear (); + trait_substs.reserve (other.trait_substs.size ()); + for (const auto &p : other.trait_substs) + trait_substs.push_back (p.clone ()); + + return *this; + } + + TraitReference (TraitReference &&other) = default; + TraitReference &operator= (TraitReference &&other) = default; + + static TraitReference error () + { + return TraitReference (nullptr, {}, {}, {}); + } + + bool is_error () const { return hir_trait_ref == nullptr; } + + static TraitReference &error_node () + { + static TraitReference trait_error_node = TraitReference::error (); + return trait_error_node; + } + + Location get_locus () const { return hir_trait_ref->get_locus (); } + + std::string get_name () const + { + rust_assert (!is_error ()); + return hir_trait_ref->get_name (); + } + + std::string as_string () const + { + if (is_error ()) + return ""; + + std::string item_buf; + for (auto &item : item_refs) + { + item_buf += item.as_string () + ", "; + } + return "HIR Trait: " + get_name () + "->" + + hir_trait_ref->get_mappings ().as_string () + " [" + item_buf + + "]"; + } + + const HIR::Trait *get_hir_trait_ref () const { return hir_trait_ref; } + + const Analysis::NodeMapping &get_mappings () const + { + return hir_trait_ref->get_mappings (); + } + + DefId get_defid () const { return get_mappings ().get_defid (); } + + bool lookup_hir_trait_item (const HIR::TraitItem &item, + TraitItemReference **ref) + { + return lookup_trait_item (item.trait_identifier (), ref); + } + + bool lookup_trait_item (const std::string &ident, TraitItemReference **ref) + { + for (auto &item : item_refs) + { + if (ident.compare (item.get_identifier ()) == 0) + { + *ref = &item; + return true; + } + } + return false; + } + + bool lookup_trait_item_by_type (const std::string &ident, + TraitItemReference::TraitItemType type, + TraitItemReference **ref) + { + for (auto &item : item_refs) + { + if (item.get_trait_item_type () != type) + continue; + + if (ident.compare (item.get_identifier ()) == 0) + { + *ref = &item; + return true; + } + } + return false; + } + + bool lookup_trait_item_by_type (const std::string &ident, + TraitItemReference::TraitItemType type, + const TraitItemReference **ref) const + { + for (auto &item : item_refs) + { + if (item.get_trait_item_type () != type) + continue; + + if (ident.compare (item.get_identifier ()) == 0) + { + *ref = &item; + return true; + } + } + return false; + } + + bool lookup_hir_trait_item (const HIR::TraitItem &item, + const TraitItemReference **ref) const + { + return lookup_trait_item (item.trait_identifier (), ref); + } + + bool lookup_trait_item (const std::string &ident, + const TraitItemReference **ref) const + { + for (auto &item : item_refs) + { + if (ident.compare (item.get_identifier ()) == 0) + { + *ref = &item; + return true; + } + } + return false; + } + + const TraitItemReference * + lookup_trait_item (const std::string &ident, + TraitItemReference::TraitItemType type) const + { + for (auto &item : item_refs) + { + if (item.get_trait_item_type () != type) + continue; + + if (ident.compare (item.get_identifier ()) == 0) + return &item; + } + return &TraitItemReference::error_node (); + } + + size_t size () const { return item_refs.size (); } + + const std::vector &get_trait_items () const + { + return item_refs; + } + + void on_resolved () + { + for (auto &item : item_refs) + { + item.on_resolved (); + } + } + + void clear_associated_types () + { + for (auto &item : item_refs) + { + bool is_assoc_type = item.get_trait_item_type () + == TraitItemReference::TraitItemType::TYPE; + if (is_assoc_type) + item.associated_type_reset (); + } + } + + bool is_equal (const TraitReference &other) const + { + DefId this_id = get_mappings ().get_defid (); + DefId other_id = other.get_mappings ().get_defid (); + return this_id == other_id; + } + + const std::vector get_super_traits () const + { + return super_traits; + } + + bool is_object_safe (bool emit_error, Location locus) const + { + // https: // doc.rust-lang.org/reference/items/traits.html#object-safety + std::vector non_object_super_traits; + for (auto &item : super_traits) + { + if (!item->is_object_safe (false, Location ())) + non_object_super_traits.push_back (item); + } + + std::vector non_object_safe_items; + for (auto &item : get_trait_items ()) + { + if (!item.is_object_safe ()) + non_object_safe_items.push_back (&item); + } + + bool is_safe + = non_object_super_traits.empty () && non_object_safe_items.empty (); + if (emit_error && !is_safe) + { + RichLocation r (locus); + for (auto &item : non_object_super_traits) + r.add_range (item->get_locus ()); + for (auto &item : non_object_safe_items) + r.add_range (item->get_locus ()); + + rust_error_at (r, "trait bound is not object safe"); + } + + return is_safe; + } + + bool trait_has_generics () const { return !trait_substs.empty (); } + + std::vector get_trait_substs () const + { + return trait_substs; + } + +private: + const HIR::Trait *hir_trait_ref; + std::vector item_refs; + std::vector super_traits; + std::vector trait_substs; +}; + +class AssociatedImplTrait +{ +public: + AssociatedImplTrait (TraitReference *trait, HIR::ImplBlock *impl, + TyTy::BaseType *self, + Resolver::TypeCheckContext *context) + : trait (trait), impl (impl), self (self), context (context) + {} + + TraitReference *get_trait () { return trait; } + + HIR::ImplBlock *get_impl_block () { return impl; } + + TyTy::BaseType *get_self () { return self; } + + void setup_associated_types (const TyTy::BaseType *self, + const TyTy::TypeBoundPredicate &bound); + + void reset_associated_types (); + +private: + TraitReference *trait; + HIR::ImplBlock *impl; + TyTy::BaseType *self; + Resolver::TypeCheckContext *context; +}; + +} // namespace Resolver +} // namespace Rust + +#endif // RUST_HIR_TRAIT_REF_H diff --git a/gcc/rust/typecheck/rust-hir-type-bounds.h b/gcc/rust/typecheck/rust-hir-type-bounds.h new file mode 100644 index 00000000000..44400efbbf7 --- /dev/null +++ b/gcc/rust/typecheck/rust-hir-type-bounds.h @@ -0,0 +1,77 @@ +// Copyright (C) 2021-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_HIR_TYPE_BOUNDS_H +#define RUST_HIR_TYPE_BOUNDS_H + +#include "rust-hir-type-check-base.h" +#include "rust-hir-full.h" +#include "rust-tyty.h" + +namespace Rust { +namespace Resolver { + +class TypeBoundsProbe : public TypeCheckBase +{ +public: + static std::vector> + Probe (const TyTy::BaseType *receiver) + { + TypeBoundsProbe probe (receiver); + probe.scan (); + return probe.trait_references; + } + + static bool is_bound_satisfied_for_type (TyTy::BaseType *receiver, + TraitReference *ref) + { + for (auto &bound : receiver->get_specified_bounds ()) + { + const TraitReference *b = bound.get (); + if (b->is_equal (*ref)) + return true; + } + + std::vector> bounds + = Probe (receiver); + for (auto &bound : bounds) + { + const TraitReference *b = bound.first; + if (b->is_equal (*ref)) + return true; + } + + return false; + } + +private: + void scan (); + +private: + TypeBoundsProbe (const TyTy::BaseType *receiver) + : TypeCheckBase (), receiver (receiver) + {} + + const TyTy::BaseType *receiver; + std::vector> trait_references; +}; + +} // namespace Resolver +} // namespace Rust + +#endif // RUST_HIR_TYPE_BOUNDS_H diff --git a/gcc/rust/typecheck/rust-substitution-mapper.cc b/gcc/rust/typecheck/rust-substitution-mapper.cc new file mode 100644 index 00000000000..f80368a0339 --- /dev/null +++ b/gcc/rust/typecheck/rust-substitution-mapper.cc @@ -0,0 +1,77 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#include "rust-substitution-mapper.h" +#include "rust-hir-type-check.h" + +namespace Rust { +namespace Resolver { + +TyTy::BaseType * +SubstMapperInternal::Resolve (TyTy::BaseType *base, + TyTy::SubstitutionArgumentMappings &mappings) +{ + auto context = TypeCheckContext::get (); + + SubstMapperInternal mapper (base->get_ref (), mappings); + base->accept_vis (mapper); + rust_assert (mapper.resolved != nullptr); + + // insert these new implict types into the context + TyTy::BaseType *unused = nullptr; + bool is_ty_available + = context->lookup_type (mapper.resolved->get_ty_ref (), &unused); + if (!is_ty_available) + { + context->insert_type ( + Analysis::NodeMapping (0, 0, mapper.resolved->get_ty_ref (), 0), + mapper.resolved); + } + bool is_ref_available + = context->lookup_type (mapper.resolved->get_ref (), &unused); + if (!is_ref_available) + { + context->insert_type (Analysis::NodeMapping (0, 0, + mapper.resolved->get_ref (), + 0), + mapper.resolved); + } + + return mapper.resolved; +} + +bool +SubstMapperInternal::mappings_are_bound ( + TyTy::BaseType *tyseg, TyTy::SubstitutionArgumentMappings &mappings) +{ + if (tyseg->get_kind () == TyTy::TypeKind::ADT) + { + TyTy::ADTType *adt = static_cast (tyseg); + return adt->are_mappings_bound (mappings); + } + else if (tyseg->get_kind () == TyTy::TypeKind::FNDEF) + { + TyTy::FnType *fn = static_cast (tyseg); + return fn->are_mappings_bound (mappings); + } + + return false; +} + +} // namespace Resolver +} // namespace Rust diff --git a/gcc/rust/typecheck/rust-substitution-mapper.h b/gcc/rust/typecheck/rust-substitution-mapper.h new file mode 100644 index 00000000000..028e10c0efe --- /dev/null +++ b/gcc/rust/typecheck/rust-substitution-mapper.h @@ -0,0 +1,394 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_SUBSTITUTION_MAPPER_H +#define RUST_SUBSTITUTION_MAPPER_H + +#include "rust-tyty.h" +#include "rust-tyty-visitor.h" + +namespace Rust { +namespace Resolver { + +class SubstMapper : public TyTy::TyVisitor +{ +public: + static TyTy::BaseType *Resolve (TyTy::BaseType *base, Location locus, + HIR::GenericArgs *generics = nullptr) + { + SubstMapper mapper (base->get_ref (), generics, locus); + base->accept_vis (mapper); + rust_assert (mapper.resolved != nullptr); + return mapper.resolved; + } + + static TyTy::BaseType *InferSubst (TyTy::BaseType *base, Location locus) + { + return SubstMapper::Resolve (base, locus, nullptr); + } + + bool have_generic_args () const { return generics != nullptr; } + + void visit (TyTy::FnType &type) override + { + TyTy::FnType *concrete = nullptr; + if (!have_generic_args ()) + { + TyTy::BaseType *substs = type.infer_substitions (locus); + rust_assert (substs->get_kind () == TyTy::TypeKind::FNDEF); + concrete = static_cast (substs); + } + else + { + TyTy::SubstitutionArgumentMappings mappings + = type.get_mappings_from_generic_args (*generics); + if (mappings.is_error ()) + return; + + concrete = type.handle_substitions (mappings); + } + + if (concrete != nullptr) + resolved = concrete; + } + + void visit (TyTy::ADTType &type) override + { + TyTy::ADTType *concrete = nullptr; + if (!have_generic_args ()) + { + TyTy::BaseType *substs = type.infer_substitions (locus); + rust_assert (substs->get_kind () == TyTy::TypeKind::ADT); + concrete = static_cast (substs); + } + else + { + TyTy::SubstitutionArgumentMappings mappings + = type.get_mappings_from_generic_args (*generics); + if (mappings.is_error ()) + return; + + concrete = type.handle_substitions (mappings); + } + + if (concrete != nullptr) + resolved = concrete; + } + + void visit (TyTy::PlaceholderType &type) override + { + rust_assert (type.can_resolve ()); + resolved = SubstMapper::Resolve (type.resolve (), locus, generics); + } + + void visit (TyTy::ProjectionType &type) override + { + TyTy::ProjectionType *concrete = nullptr; + if (!have_generic_args ()) + { + TyTy::BaseType *substs = type.infer_substitions (locus); + rust_assert (substs->get_kind () == TyTy::TypeKind::ADT); + concrete = static_cast (substs); + } + else + { + TyTy::SubstitutionArgumentMappings mappings + = type.get_mappings_from_generic_args (*generics); + if (mappings.is_error ()) + return; + + concrete = type.handle_substitions (mappings); + } + + if (concrete != nullptr) + resolved = concrete; + } + + // nothing to do for these + void visit (TyTy::InferType &) override { gcc_unreachable (); } + void visit (TyTy::TupleType &) override { gcc_unreachable (); } + void visit (TyTy::FnPtr &) override { gcc_unreachable (); } + void visit (TyTy::ArrayType &) override { gcc_unreachable (); } + void visit (TyTy::SliceType &) override { gcc_unreachable (); } + void visit (TyTy::BoolType &) override { gcc_unreachable (); } + void visit (TyTy::IntType &) override { gcc_unreachable (); } + void visit (TyTy::UintType &) override { gcc_unreachable (); } + void visit (TyTy::FloatType &) override { gcc_unreachable (); } + void visit (TyTy::USizeType &) override { gcc_unreachable (); } + void visit (TyTy::ISizeType &) override { gcc_unreachable (); } + void visit (TyTy::ErrorType &) override { gcc_unreachable (); } + void visit (TyTy::CharType &) override { gcc_unreachable (); } + void visit (TyTy::ReferenceType &) override { gcc_unreachable (); } + void visit (TyTy::PointerType &) override { gcc_unreachable (); } + void visit (TyTy::ParamType &) override { gcc_unreachable (); } + void visit (TyTy::StrType &) override { gcc_unreachable (); } + void visit (TyTy::NeverType &) override { gcc_unreachable (); } + void visit (TyTy::DynamicObjectType &) override { gcc_unreachable (); } + void visit (TyTy::ClosureType &) override { gcc_unreachable (); } + +private: + SubstMapper (HirId ref, HIR::GenericArgs *generics, Location locus) + : resolved (new TyTy::ErrorType (ref)), generics (generics), locus (locus) + {} + + TyTy::BaseType *resolved; + HIR::GenericArgs *generics; + Location locus; +}; + +class SubstMapperInternal : public TyTy::TyVisitor +{ +public: + static TyTy::BaseType *Resolve (TyTy::BaseType *base, + TyTy::SubstitutionArgumentMappings &mappings); + + static bool mappings_are_bound (TyTy::BaseType *ty, + TyTy::SubstitutionArgumentMappings &mappings); + + void visit (TyTy::FnType &type) override + { + TyTy::SubstitutionArgumentMappings adjusted + = type.adjust_mappings_for_this (mappings); + if (adjusted.is_error ()) + return; + + TyTy::BaseType *concrete = type.handle_substitions (adjusted); + if (concrete != nullptr) + resolved = concrete; + } + + void visit (TyTy::ADTType &type) override + { + TyTy::SubstitutionArgumentMappings adjusted + = type.adjust_mappings_for_this (mappings); + if (adjusted.is_error ()) + return; + + TyTy::BaseType *concrete = type.handle_substitions (adjusted); + if (concrete != nullptr) + resolved = concrete; + } + + // these don't support generic arguments but might contain a type param + void visit (TyTy::TupleType &type) override + { + resolved = type.handle_substitions (mappings); + } + + void visit (TyTy::ReferenceType &type) override + { + resolved = type.handle_substitions (mappings); + } + + void visit (TyTy::PointerType &type) override + { + resolved = type.handle_substitions (mappings); + } + + void visit (TyTy::ParamType &type) override + { + resolved = type.handle_substitions (mappings); + } + + void visit (TyTy::PlaceholderType &type) override + { + rust_assert (type.can_resolve ()); + if (mappings.trait_item_mode ()) + { + resolved = type.resolve (); + } + else + { + resolved = SubstMapperInternal::Resolve (type.resolve (), mappings); + } + } + + void visit (TyTy::ProjectionType &type) override + { + resolved = type.handle_substitions (mappings); + } + + void visit (TyTy::ClosureType &type) override + { + resolved = type.handle_substitions (mappings); + } + + void visit (TyTy::ArrayType &type) override + { + resolved = type.handle_substitions (mappings); + } + + void visit (TyTy::SliceType &type) override + { + resolved = type.handle_substitions (mappings); + } + + // nothing to do for these + void visit (TyTy::InferType &type) override { resolved = type.clone (); } + void visit (TyTy::FnPtr &type) override { resolved = type.clone (); } + void visit (TyTy::BoolType &type) override { resolved = type.clone (); } + void visit (TyTy::IntType &type) override { resolved = type.clone (); } + void visit (TyTy::UintType &type) override { resolved = type.clone (); } + void visit (TyTy::FloatType &type) override { resolved = type.clone (); } + void visit (TyTy::USizeType &type) override { resolved = type.clone (); } + void visit (TyTy::ISizeType &type) override { resolved = type.clone (); } + void visit (TyTy::ErrorType &type) override { resolved = type.clone (); } + void visit (TyTy::CharType &type) override { resolved = type.clone (); } + void visit (TyTy::StrType &type) override { resolved = type.clone (); } + void visit (TyTy::NeverType &type) override { resolved = type.clone (); } + void visit (TyTy::DynamicObjectType &type) override + { + resolved = type.clone (); + } + +private: + SubstMapperInternal (HirId ref, TyTy::SubstitutionArgumentMappings &mappings) + : resolved (new TyTy::ErrorType (ref)), mappings (mappings) + {} + + TyTy::BaseType *resolved; + TyTy::SubstitutionArgumentMappings &mappings; +}; + +class SubstMapperFromExisting : public TyTy::TyVisitor +{ +public: + static TyTy::BaseType *Resolve (TyTy::BaseType *concrete, + TyTy::BaseType *receiver) + { + rust_assert (concrete->get_kind () == receiver->get_kind ()); + + SubstMapperFromExisting mapper (concrete, receiver); + concrete->accept_vis (mapper); + return mapper.resolved; + } + + void visit (TyTy::FnType &type) override + { + rust_assert (type.was_substituted ()); + + TyTy::FnType *to_sub = static_cast (receiver); + resolved = to_sub->handle_substitions (type.get_substitution_arguments ()); + } + + void visit (TyTy::ADTType &type) override + { + rust_assert (type.was_substituted ()); + + TyTy::ADTType *to_sub = static_cast (receiver); + resolved = to_sub->handle_substitions (type.get_substitution_arguments ()); + } + + void visit (TyTy::ClosureType &type) override + { + rust_assert (type.was_substituted ()); + + TyTy::ClosureType *to_sub = static_cast (receiver); + resolved = to_sub->handle_substitions (type.get_substitution_arguments ()); + } + + void visit (TyTy::InferType &) override { gcc_unreachable (); } + void visit (TyTy::TupleType &) override { gcc_unreachable (); } + void visit (TyTy::FnPtr &) override { gcc_unreachable (); } + void visit (TyTy::ArrayType &) override { gcc_unreachable (); } + void visit (TyTy::SliceType &) override { gcc_unreachable (); } + void visit (TyTy::BoolType &) override { gcc_unreachable (); } + void visit (TyTy::IntType &) override { gcc_unreachable (); } + void visit (TyTy::UintType &) override { gcc_unreachable (); } + void visit (TyTy::FloatType &) override { gcc_unreachable (); } + void visit (TyTy::USizeType &) override { gcc_unreachable (); } + void visit (TyTy::ISizeType &) override { gcc_unreachable (); } + void visit (TyTy::ErrorType &) override { gcc_unreachable (); } + void visit (TyTy::CharType &) override { gcc_unreachable (); } + void visit (TyTy::ReferenceType &) override { gcc_unreachable (); } + void visit (TyTy::PointerType &) override { gcc_unreachable (); } + void visit (TyTy::ParamType &) override { gcc_unreachable (); } + void visit (TyTy::StrType &) override { gcc_unreachable (); } + void visit (TyTy::NeverType &) override { gcc_unreachable (); } + void visit (TyTy::PlaceholderType &) override { gcc_unreachable (); } + void visit (TyTy::ProjectionType &) override { gcc_unreachable (); } + void visit (TyTy::DynamicObjectType &) override { gcc_unreachable (); } + +private: + SubstMapperFromExisting (TyTy::BaseType *concrete, TyTy::BaseType *receiver) + : concrete (concrete), receiver (receiver), resolved (nullptr) + {} + + TyTy::BaseType *concrete; + TyTy::BaseType *receiver; + + TyTy::BaseType *resolved; +}; + +class GetUsedSubstArgs : public TyTy::TyConstVisitor +{ +public: + static TyTy::SubstitutionArgumentMappings From (const TyTy::BaseType *from) + { + GetUsedSubstArgs mapper; + from->accept_vis (mapper); + return mapper.args; + } + + void visit (const TyTy::FnType &type) override + { + args = type.get_substitution_arguments (); + } + + void visit (const TyTy::ADTType &type) override + { + args = type.get_substitution_arguments (); + } + + void visit (const TyTy::ClosureType &type) override + { + args = type.get_substitution_arguments (); + } + + void visit (const TyTy::InferType &) override {} + void visit (const TyTy::TupleType &) override {} + void visit (const TyTy::FnPtr &) override {} + void visit (const TyTy::ArrayType &) override {} + void visit (const TyTy::SliceType &) override {} + void visit (const TyTy::BoolType &) override {} + void visit (const TyTy::IntType &) override {} + void visit (const TyTy::UintType &) override {} + void visit (const TyTy::FloatType &) override {} + void visit (const TyTy::USizeType &) override {} + void visit (const TyTy::ISizeType &) override {} + void visit (const TyTy::ErrorType &) override {} + void visit (const TyTy::CharType &) override {} + void visit (const TyTy::ReferenceType &) override {} + void visit (const TyTy::PointerType &) override {} + void visit (const TyTy::ParamType &) override {} + void visit (const TyTy::StrType &) override {} + void visit (const TyTy::NeverType &) override {} + void visit (const TyTy::PlaceholderType &) override {} + void visit (const TyTy::ProjectionType &) override {} + void visit (const TyTy::DynamicObjectType &) override {} + +private: + GetUsedSubstArgs () : args (TyTy::SubstitutionArgumentMappings::error ()) {} + + TyTy::SubstitutionArgumentMappings args; +}; + +} // namespace Resolver +} // namespace Rust + +#endif // RUST_SUBSTITUTION_MAPPER_H diff --git a/gcc/rust/typecheck/rust-tycheck-dump.h b/gcc/rust/typecheck/rust-tycheck-dump.h new file mode 100644 index 00000000000..ccf0f625e4b --- /dev/null +++ b/gcc/rust/typecheck/rust-tycheck-dump.h @@ -0,0 +1,239 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_TYCHECK_DUMP +#define RUST_TYCHECK_DUMP + +#include "rust-hir-type-check-base.h" +#include "rust-hir-full.h" + +namespace Rust { +namespace Resolver { + +class TypeResolverDump : private TypeCheckBase, private HIR::HIRFullVisitorBase +{ + using HIR::HIRFullVisitorBase::visit; + +public: + static void go (HIR::Crate &crate, std::ofstream &out) + { + TypeResolverDump dumper; + for (auto &item : crate.items) + { + item->accept_vis (dumper); + dumper.dump += "\n"; + } + + out << dumper.dump; + } + + void visit (HIR::StructStruct &struct_decl) override + { + dump += indent () + "struct " + type_string (struct_decl.get_mappings ()) + + "\n"; + } + + void visit (HIR::Union &union_decl) override + { + dump + += indent () + "union " + type_string (union_decl.get_mappings ()) + "\n"; + } + + void visit (HIR::TupleStruct &struct_decl) override + { + dump += indent () + "struct" + type_string (struct_decl.get_mappings ()) + + "\n"; + } + + void visit (HIR::ImplBlock &impl_block) override + { + dump += indent () + "impl " + + type_string (impl_block.get_type ()->get_mappings ()) + " {\n"; + indentation_level++; + + for (auto &impl_item : impl_block.get_impl_items ()) + { + impl_item->accept_vis (*this); + dump += "\n"; + } + + indentation_level--; + dump += indent () + "}\n"; + } + + void visit (HIR::ConstantItem &constant) override + { + dump += indent () + "constant " + constant.get_identifier () + ":" + + type_string (constant.get_mappings ()) + " = "; + constant.get_expr ()->accept_vis (*this); + dump += ";\n"; + } + + void visit (HIR::Function &function) override + { + dump += indent () + "fn " + function.get_function_name () + " " + + type_string (function.get_mappings ()) + "\n"; + dump += indent () + "{\n"; + + HIR::BlockExpr *function_body = function.get_definition ().get (); + function_body->accept_vis (*this); + + dump += indent () + "}\n"; + } + + void visit (HIR::BlockExpr &expr) override + { + dump += "{\n"; + indentation_level++; + + for (auto &s : expr.get_statements ()) + { + dump += indent (); + s->accept_vis (*this); + dump += ";\n"; + } + + if (expr.has_expr ()) + { + dump += indent (); + expr.expr->accept_vis (*this); + dump += ";\n"; + } + + indentation_level--; + dump += "}\n"; + } + + void visit (HIR::UnsafeBlockExpr &expr) override + { + dump += "unsafe "; + expr.get_block_expr ()->accept_vis (*this); + } + + void visit (HIR::LetStmt &stmt) override + { + dump += "let " + stmt.get_pattern ()->as_string () + ":" + + type_string (stmt.get_pattern ()->get_pattern_mappings ()); + if (stmt.has_init_expr ()) + { + dump += " = "; + stmt.get_init_expr ()->accept_vis (*this); + } + } + + void visit (HIR::ExprStmtWithBlock &stmt) override + { + stmt.get_expr ()->accept_vis (*this); + } + + void visit (HIR::ExprStmtWithoutBlock &stmt) override + { + stmt.get_expr ()->accept_vis (*this); + } + + void visit (HIR::AssignmentExpr &expr) override + { + expr.get_lhs ()->accept_vis (*this); + dump += " = "; + expr.get_rhs ()->accept_vis (*this); + } + + void visit (HIR::LiteralExpr &expr) override + { + dump += expr.get_literal ().as_string () + ":" + + type_string (expr.get_mappings ()); + } + + void visit (HIR::ArrayExpr &expr) override + { + dump += type_string (expr.get_mappings ()) + ":["; + + HIR::ArrayElems *elements = expr.get_internal_elements (); + elements->accept_vis (*this); + + dump += "]"; + } + + void visit (HIR::ArrayElemsValues &elems) override + { + for (auto &elem : elems.get_values ()) + { + elem->accept_vis (*this); + dump += ","; + } + } + + void visit (HIR::GroupedExpr &expr) override + { + HIR::Expr *paren_expr = expr.get_expr_in_parens ().get (); + dump += "("; + paren_expr->accept_vis (*this); + dump += ")"; + } + + void visit (HIR::PathInExpression &expr) override + { + dump += type_string (expr.get_mappings ()); + } + + void visit (HIR::StructExprStructFields &expr) override + { + dump += "ctor: " + type_string (expr.get_mappings ()); + } + +protected: + std::string type_string (const Analysis::NodeMapping &mappings) + { + TyTy::BaseType *lookup = nullptr; + if (!context->lookup_type (mappings.get_hirid (), &lookup)) + return ""; + + std::string buf = "["; + for (auto &ref : lookup->get_combined_refs ()) + { + buf += std::to_string (ref); + buf += ", "; + } + buf += "]"; + + std::string repr = lookup->as_string (); + return "<" + repr + " HIRID: " + std::to_string (mappings.get_hirid ()) + + " RF:" + std::to_string (lookup->get_ref ()) + " TF:" + + std::to_string (lookup->get_ty_ref ()) + +" - " + buf + ">"; + } + + std::string indent () + { + std::string buf; + for (size_t i = 0; i < indentation_level; ++i) + buf += " "; + + return buf; + } + +private: + TypeResolverDump () : TypeCheckBase (), indentation_level (0) {} + + std::string dump; + size_t indentation_level; +}; + +} // namespace Resolver +} // namespace Rust + +#endif // RUST_TYCHECK_DUMP diff --git a/gcc/rust/typecheck/rust-tyctx.cc b/gcc/rust/typecheck/rust-tyctx.cc new file mode 100644 index 00000000000..d8a49e8b9ea --- /dev/null +++ b/gcc/rust/typecheck/rust-tyctx.cc @@ -0,0 +1,155 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#include "rust-hir-type-check.h" + +namespace Rust { +namespace Resolver { + +TypeCheckContext * +TypeCheckContext::get () +{ + static TypeCheckContext *instance; + if (instance == nullptr) + instance = new TypeCheckContext (); + + return instance; +} + +TypeCheckContext::TypeCheckContext () {} + +TypeCheckContext::~TypeCheckContext () {} + +bool +TypeCheckContext::lookup_builtin (NodeId id, TyTy::BaseType **type) +{ + auto ref_it = node_id_refs.find (id); + if (ref_it == node_id_refs.end ()) + return false; + + auto it = resolved.find (ref_it->second); + if (it == resolved.end ()) + return false; + + *type = it->second; + return true; +} + +bool +TypeCheckContext::lookup_builtin (std::string name, TyTy::BaseType **type) +{ + for (auto &builtin : builtins) + { + if (name.compare (builtin->as_string ()) == 0) + { + *type = builtin.get (); + return true; + } + } + return false; +} + +void +TypeCheckContext::insert_builtin (HirId id, NodeId ref, TyTy::BaseType *type) +{ + node_id_refs[ref] = id; + resolved[id] = type; + builtins.push_back (std::unique_ptr (type)); +} + +void +TypeCheckContext::insert_type (const Analysis::NodeMapping &mappings, + TyTy::BaseType *type) +{ + rust_assert (type != nullptr); + NodeId ref = mappings.get_nodeid (); + HirId id = mappings.get_hirid (); + node_id_refs[ref] = id; + resolved[id] = type; +} + +void +TypeCheckContext::insert_implicit_type (TyTy::BaseType *type) +{ + rust_assert (type != nullptr); + resolved[type->get_ref ()] = type; +} + +void +TypeCheckContext::insert_implicit_type (HirId id, TyTy::BaseType *type) +{ + rust_assert (type != nullptr); + resolved[id] = type; +} + +bool +TypeCheckContext::lookup_type (HirId id, TyTy::BaseType **type) const +{ + auto it = resolved.find (id); + if (it == resolved.end ()) + return false; + + *type = it->second; + return true; +} + +void +TypeCheckContext::insert_type_by_node_id (NodeId ref, HirId id) +{ + rust_assert (node_id_refs.find (ref) == node_id_refs.end ()); + node_id_refs[ref] = id; +} + +bool +TypeCheckContext::lookup_type_by_node_id (NodeId ref, HirId *id) +{ + auto it = node_id_refs.find (ref); + if (it == node_id_refs.end ()) + return false; + + *id = it->second; + return true; +} + +TyTy::BaseType * +TypeCheckContext::peek_return_type () +{ + return return_type_stack.back ().second; +} + +void +TypeCheckContext::push_return_type (TypeCheckContextItem item, + TyTy::BaseType *return_type) +{ + return_type_stack.push_back ({std::move (item), return_type}); +} + +void +TypeCheckContext::pop_return_type () +{ + return_type_stack.pop_back (); +} + +TypeCheckContextItem & +TypeCheckContext::peek_context () +{ + return return_type_stack.back ().first; +} + +} // namespace Resolver +} // namespace Rust diff --git a/gcc/rust/typecheck/rust-tyty-bounds.cc b/gcc/rust/typecheck/rust-tyty-bounds.cc new file mode 100644 index 00000000000..7a1562ab544 --- /dev/null +++ b/gcc/rust/typecheck/rust-tyty-bounds.cc @@ -0,0 +1,462 @@ +// Copyright (C) 2021-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#include "rust-hir-type-bounds.h" +#include "rust-hir-trait-resolve.h" + +namespace Rust { +namespace Resolver { + +void +TypeBoundsProbe::scan () +{ + std::vector> + possible_trait_paths; + mappings->iterate_impl_blocks ( + [&] (HirId id, HIR::ImplBlock *impl) mutable -> bool { + // we are filtering for trait-impl-blocks + if (!impl->has_trait_ref ()) + return true; + + TyTy::BaseType *impl_type = nullptr; + bool ok + = context->lookup_type (impl->get_type ()->get_mappings ().get_hirid (), + &impl_type); + if (!ok) + return true; + + if (!receiver->can_eq (impl_type, false)) + { + if (!impl_type->can_eq (receiver, false)) + return true; + } + + possible_trait_paths.push_back ({impl->get_trait_ref ().get (), impl}); + return true; + }); + + for (auto &path : possible_trait_paths) + { + HIR::TypePath *trait_path = path.first; + TraitReference *trait_ref = TraitResolver::Resolve (*trait_path); + + if (!trait_ref->is_error ()) + trait_references.push_back ({trait_ref, path.second}); + } +} + +TraitReference * +TypeCheckBase::resolve_trait_path (HIR::TypePath &path) +{ + return TraitResolver::Resolve (path); +} + +TyTy::TypeBoundPredicate +TypeCheckBase::get_predicate_from_bound (HIR::TypePath &type_path) +{ + TraitReference *trait = resolve_trait_path (type_path); + if (trait->is_error ()) + return TyTy::TypeBoundPredicate::error (); + + TyTy::TypeBoundPredicate predicate (*trait, type_path.get_locus ()); + HIR::GenericArgs args + = HIR::GenericArgs::create_empty (type_path.get_locus ()); + + auto &final_seg = type_path.get_final_segment (); + if (final_seg->is_generic_segment ()) + { + auto final_generic_seg + = static_cast (final_seg.get ()); + if (final_generic_seg->has_generic_args ()) + { + args = final_generic_seg->get_generic_args (); + } + } + + if (predicate.requires_generic_args ()) + { + // this is applying generic arguments to a trait reference + predicate.apply_generic_arguments (&args); + } + + return predicate; +} + +} // namespace Resolver + +namespace TyTy { + +TypeBoundPredicate::TypeBoundPredicate ( + const Resolver::TraitReference &trait_reference, Location locus) + : SubstitutionRef ({}, SubstitutionArgumentMappings::error ()), + reference (trait_reference.get_mappings ().get_defid ()), locus (locus), + error_flag (false) +{ + substitutions.clear (); + for (const auto &p : trait_reference.get_trait_substs ()) + substitutions.push_back (p.clone ()); + + // we setup a dummy implict self argument + SubstitutionArg placeholder_self (&get_substs ().front (), nullptr); + used_arguments.get_mappings ().push_back (placeholder_self); +} + +TypeBoundPredicate::TypeBoundPredicate ( + DefId reference, std::vector subst, Location locus) + : SubstitutionRef ({}, SubstitutionArgumentMappings::error ()), + reference (reference), locus (locus), error_flag (false) +{ + substitutions.clear (); + for (const auto &p : subst) + substitutions.push_back (p.clone ()); + + // we setup a dummy implict self argument + SubstitutionArg placeholder_self (&get_substs ().front (), nullptr); + used_arguments.get_mappings ().push_back (placeholder_self); +} + +TypeBoundPredicate::TypeBoundPredicate (const TypeBoundPredicate &other) + : SubstitutionRef ({}, SubstitutionArgumentMappings::error ()), + reference (other.reference), locus (other.locus), + error_flag (other.error_flag) +{ + substitutions.clear (); + for (const auto &p : other.get_substs ()) + substitutions.push_back (p.clone ()); + + std::vector mappings; + for (size_t i = 0; i < other.used_arguments.get_mappings ().size (); i++) + { + const SubstitutionArg &oa = other.used_arguments.get_mappings ().at (i); + SubstitutionArg arg (oa); + mappings.push_back (std::move (arg)); + } + + // we need to remap the argument mappings based on this copied constructor + std::vector copied_arg_mappings; + size_t i = 0; + for (const auto &m : other.used_arguments.get_mappings ()) + { + TyTy::BaseType *argument + = m.get_tyty () == nullptr ? nullptr : m.get_tyty ()->clone (); + SubstitutionArg c (&substitutions.at (i++), argument); + copied_arg_mappings.push_back (std::move (c)); + } + + used_arguments + = SubstitutionArgumentMappings (copied_arg_mappings, + other.used_arguments.get_locus ()); +} + +TypeBoundPredicate & +TypeBoundPredicate::operator= (const TypeBoundPredicate &other) +{ + reference = other.reference; + locus = other.locus; + error_flag = other.error_flag; + used_arguments = SubstitutionArgumentMappings::error (); + + substitutions.clear (); + for (const auto &p : other.get_substs ()) + substitutions.push_back (p.clone ()); + + std::vector mappings; + for (size_t i = 0; i < other.used_arguments.get_mappings ().size (); i++) + { + const SubstitutionArg &oa = other.used_arguments.get_mappings ().at (i); + SubstitutionArg arg (oa); + mappings.push_back (std::move (arg)); + } + + // we need to remap the argument mappings based on this copied constructor + std::vector copied_arg_mappings; + size_t i = 0; + for (const auto &m : other.used_arguments.get_mappings ()) + { + TyTy::BaseType *argument + = m.get_tyty () == nullptr ? nullptr : m.get_tyty ()->clone (); + SubstitutionArg c (&substitutions.at (i++), argument); + copied_arg_mappings.push_back (std::move (c)); + } + + used_arguments + = SubstitutionArgumentMappings (copied_arg_mappings, + other.used_arguments.get_locus ()); + + return *this; +} + +TypeBoundPredicate +TypeBoundPredicate::error () +{ + auto p = TypeBoundPredicate (UNKNOWN_DEFID, {}, Location ()); + p.error_flag = true; + return p; +} + +std::string +TypeBoundPredicate::as_string () const +{ + return get ()->as_string () + subst_as_string (); +} + +std::string +TypeBoundPredicate::as_name () const +{ + return get ()->get_name () + subst_as_string (); +} + +const Resolver::TraitReference * +TypeBoundPredicate::get () const +{ + auto context = Resolver::TypeCheckContext::get (); + + Resolver::TraitReference *ref = nullptr; + bool ok = context->lookup_trait_reference (reference, &ref); + rust_assert (ok); + + return ref; +} + +std::string +TypeBoundPredicate::get_name () const +{ + return get ()->get_name (); +} + +bool +TypeBoundPredicate::is_object_safe (bool emit_error, Location locus) const +{ + const Resolver::TraitReference *trait = get (); + rust_assert (trait != nullptr); + return trait->is_object_safe (emit_error, locus); +} + +void +TypeBoundPredicate::apply_generic_arguments (HIR::GenericArgs *generic_args) +{ + // we need to get the substitutions argument mappings but also remember that + // we have an implicit Self argument which we must be careful to respect + rust_assert (!used_arguments.is_empty ()); + rust_assert (!substitutions.empty ()); + + // now actually perform a substitution + used_arguments = get_mappings_from_generic_args (*generic_args); + + error_flag |= used_arguments.is_error (); + auto &subst_mappings = used_arguments; + for (auto &sub : get_substs ()) + { + SubstitutionArg arg = SubstitutionArg::error (); + bool ok + = subst_mappings.get_argument_for_symbol (sub.get_param_ty (), &arg); + if (ok && arg.get_tyty () != nullptr) + sub.fill_param_ty (subst_mappings, subst_mappings.get_locus ()); + } +} + +bool +TypeBoundPredicate::contains_item (const std::string &search) const +{ + auto trait_ref = get (); + const Resolver::TraitItemReference *trait_item_ref = nullptr; + return trait_ref->lookup_trait_item (search, &trait_item_ref); +} + +TypeBoundPredicateItem +TypeBoundPredicate::lookup_associated_item (const std::string &search) const +{ + auto trait_ref = get (); + const Resolver::TraitItemReference *trait_item_ref = nullptr; + if (!trait_ref->lookup_trait_item (search, &trait_item_ref)) + return TypeBoundPredicateItem::error (); + + return TypeBoundPredicateItem (this, trait_item_ref); +} + +TypeBoundPredicateItem +TypeBoundPredicate::lookup_associated_item ( + const Resolver::TraitItemReference *ref) const +{ + return lookup_associated_item (ref->get_identifier ()); +} + +BaseType * +TypeBoundPredicateItem::get_tyty_for_receiver (const TyTy::BaseType *receiver) +{ + TyTy::BaseType *trait_item_tyty = get_raw_item ()->get_tyty (); + if (parent->get_substitution_arguments ().is_empty ()) + return trait_item_tyty; + + const Resolver::TraitItemReference *tref = get_raw_item (); + bool is_associated_type = tref->get_trait_item_type (); + if (is_associated_type) + return trait_item_tyty; + + // set up the self mapping + SubstitutionArgumentMappings gargs = parent->get_substitution_arguments (); + rust_assert (!gargs.is_empty ()); + + // setup the adjusted mappings + std::vector adjusted_mappings; + for (size_t i = 0; i < gargs.get_mappings ().size (); i++) + { + auto &mapping = gargs.get_mappings ().at (i); + + bool is_implicit_self = i == 0; + TyTy::BaseType *argument + = is_implicit_self ? receiver->clone () : mapping.get_tyty (); + + SubstitutionArg arg (mapping.get_param_mapping (), argument); + adjusted_mappings.push_back (std::move (arg)); + } + + SubstitutionArgumentMappings adjusted (adjusted_mappings, gargs.get_locus (), + gargs.get_subst_cb (), + true /* trait-mode-flag */); + return Resolver::SubstMapperInternal::Resolve (trait_item_tyty, adjusted); +} +bool +TypeBoundPredicate::is_error () const +{ + auto context = Resolver::TypeCheckContext::get (); + + Resolver::TraitReference *ref = nullptr; + bool ok = context->lookup_trait_reference (reference, &ref); + + return !ok || error_flag; +} + +BaseType * +TypeBoundPredicate::handle_substitions ( + SubstitutionArgumentMappings subst_mappings) +{ + for (auto &sub : get_substs ()) + { + if (sub.get_param_ty () == nullptr) + continue; + + ParamType *p = sub.get_param_ty (); + BaseType *r = p->resolve (); + BaseType *s = Resolver::SubstMapperInternal::Resolve (r, subst_mappings); + + p->set_ty_ref (s->get_ty_ref ()); + } + + // FIXME more error handling at some point + // used_arguments = subst_mappings; + // error_flag |= used_arguments.is_error (); + + return nullptr; +} + +bool +TypeBoundPredicate::requires_generic_args () const +{ + if (is_error ()) + return false; + + return substitutions.size () > 1; +} + +// trait item reference + +const Resolver::TraitItemReference * +TypeBoundPredicateItem::get_raw_item () const +{ + return trait_item_ref; +} + +bool +TypeBoundPredicateItem::needs_implementation () const +{ + return !get_raw_item ()->is_optional (); +} + +Location +TypeBoundPredicateItem::get_locus () const +{ + return get_raw_item ()->get_locus (); +} + +// TypeBoundsMappings + +TypeBoundsMappings::TypeBoundsMappings ( + std::vector specified_bounds) + : specified_bounds (specified_bounds) +{} + +std::vector & +TypeBoundsMappings::get_specified_bounds () +{ + return specified_bounds; +} + +const std::vector & +TypeBoundsMappings::get_specified_bounds () const +{ + return specified_bounds; +} + +size_t +TypeBoundsMappings::num_specified_bounds () const +{ + return specified_bounds.size (); +} + +std::string +TypeBoundsMappings::raw_bounds_as_string () const +{ + std::string buf; + for (size_t i = 0; i < specified_bounds.size (); i++) + { + const TypeBoundPredicate &b = specified_bounds.at (i); + bool has_next = (i + 1) < specified_bounds.size (); + buf += b.as_string () + (has_next ? " + " : ""); + } + return buf; +} + +std::string +TypeBoundsMappings::bounds_as_string () const +{ + return "bounds:[" + raw_bounds_as_string () + "]"; +} + +std::string +TypeBoundsMappings::raw_bounds_as_name () const +{ + std::string buf; + for (size_t i = 0; i < specified_bounds.size (); i++) + { + const TypeBoundPredicate &b = specified_bounds.at (i); + bool has_next = (i + 1) < specified_bounds.size (); + buf += b.as_name () + (has_next ? " + " : ""); + } + + return buf; +} + +void +TypeBoundsMappings::add_bound (TypeBoundPredicate predicate) +{ + specified_bounds.push_back (predicate); +} + +} // namespace TyTy +} // namespace Rust diff --git a/gcc/rust/typecheck/rust-tyty-call.cc b/gcc/rust/typecheck/rust-tyty-call.cc new file mode 100644 index 00000000000..1ce82c943f5 --- /dev/null +++ b/gcc/rust/typecheck/rust-tyty-call.cc @@ -0,0 +1,263 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#include "rust-tyty-call.h" +#include "rust-hir-type-check-expr.h" + +namespace Rust { +namespace TyTy { + +void +TypeCheckCallExpr::visit (ADTType &type) +{ + rust_assert (!variant.is_error ()); + if (variant.get_variant_type () != TyTy::VariantDef::VariantType::TUPLE) + { + rust_error_at ( + call.get_locus (), + "expected function, tuple struct or tuple variant, found struct %<%s%>", + type.get_name ().c_str ()); + return; + } + + if (call.num_params () != variant.num_fields ()) + { + rust_error_at (call.get_locus (), + "unexpected number of arguments %lu expected %lu", + (unsigned long) call.num_params (), + (unsigned long) variant.num_fields ()); + return; + } + + size_t i = 0; + for (auto &argument : call.get_arguments ()) + { + StructFieldType *field = variant.get_field_at_index (i); + BaseType *field_tyty = field->get_field_type (); + + BaseType *arg = Resolver::TypeCheckExpr::Resolve (argument.get ()); + if (arg->get_kind () == TyTy::TypeKind::ERROR) + { + rust_error_at (argument->get_locus (), + "failed to resolve argument type"); + return; + } + + auto res = Resolver::TypeCheckBase::coercion_site ( + argument->get_mappings ().get_hirid (), field_tyty, arg, + argument->get_locus ()); + if (res->get_kind () == TyTy::TypeKind::ERROR) + { + return; + } + + delete res; + i++; + } + + if (i != call.num_params ()) + { + rust_error_at (call.get_locus (), + "unexpected number of arguments %lu expected %lu", + (unsigned long) i, (unsigned long) call.num_params ()); + return; + } + + resolved = type.clone (); +} + +void +TypeCheckCallExpr::visit (FnType &type) +{ + type.monomorphize (); + if (call.num_params () != type.num_params ()) + { + if (type.is_varadic ()) + { + if (call.num_params () < type.num_params ()) + { + rust_error_at (call.get_locus (), + "unexpected number of arguments %lu expected %lu", + (unsigned long) call.num_params (), + (unsigned long) type.num_params ()); + return; + } + } + else + { + rust_error_at (call.get_locus (), + "unexpected number of arguments %lu expected %lu", + (unsigned long) call.num_params (), + (unsigned long) type.num_params ()); + return; + } + } + + size_t i = 0; + for (auto &argument : call.get_arguments ()) + { + auto argument_expr_tyty + = Resolver::TypeCheckExpr::Resolve (argument.get ()); + if (argument_expr_tyty->get_kind () == TyTy::TypeKind::ERROR) + { + rust_error_at ( + argument->get_locus (), + "failed to resolve type for argument expr in CallExpr"); + return; + } + + // it might be a varadic function + if (i < type.num_params ()) + { + auto fnparam = type.param_at (i); + auto resolved_argument_type = Resolver::TypeCheckBase::coercion_site ( + argument->get_mappings ().get_hirid (), fnparam.second, + argument_expr_tyty, argument->get_locus ()); + if (resolved_argument_type->get_kind () == TyTy::TypeKind::ERROR) + { + rust_error_at (argument->get_locus (), + "Type Resolution failure on parameter"); + return; + } + } + + i++; + } + + if (i < call.num_params ()) + { + rust_error_at (call.get_locus (), + "unexpected number of arguments %lu expected %lu", + (unsigned long) i, (unsigned long) call.num_params ()); + return; + } + + type.monomorphize (); + resolved = type.get_return_type ()->clone (); +} + +void +TypeCheckCallExpr::visit (FnPtr &type) +{ + if (call.num_params () != type.num_params ()) + { + rust_error_at (call.get_locus (), + "unexpected number of arguments %lu expected %lu", + (unsigned long) call.num_params (), + (unsigned long) type.num_params ()); + return; + } + + size_t i = 0; + for (auto &argument : call.get_arguments ()) + { + auto fnparam = type.param_at (i); + auto argument_expr_tyty + = Resolver::TypeCheckExpr::Resolve (argument.get ()); + if (argument_expr_tyty->get_kind () == TyTy::TypeKind::ERROR) + { + rust_error_at ( + argument->get_locus (), + "failed to resolve type for argument expr in CallExpr"); + return; + } + + auto resolved_argument_type = Resolver::TypeCheckBase::coercion_site ( + argument->get_mappings ().get_hirid (), fnparam, argument_expr_tyty, + argument->get_locus ()); + if (resolved_argument_type->get_kind () == TyTy::TypeKind::ERROR) + { + rust_error_at (argument->get_locus (), + "Type Resolution failure on parameter"); + return; + } + + i++; + } + + if (i != call.num_params ()) + { + rust_error_at (call.get_locus (), + "unexpected number of arguments %lu expected %lu", + (unsigned long) i, (unsigned long) call.num_params ()); + return; + } + + resolved = type.get_return_type ()->monomorphized_clone (); +} + +// method call checker + +void +TypeCheckMethodCallExpr::visit (FnType &type) +{ + type.get_self_type ()->unify (adjusted_self); + + // +1 for the receiver self + size_t num_args_to_call = call.num_params () + 1; + if (num_args_to_call != type.num_params ()) + { + rust_error_at (call.get_locus (), + "unexpected number of arguments %lu expected %lu", + (unsigned long) call.num_params (), + (unsigned long) type.num_params ()); + return; + } + + size_t i = 1; + for (auto &argument : call.get_arguments ()) + { + auto fnparam = type.param_at (i); + auto argument_expr_tyty + = Resolver::TypeCheckExpr::Resolve (argument.get ()); + if (argument_expr_tyty->get_kind () == TyTy::TypeKind::ERROR) + { + rust_error_at ( + argument->get_locus (), + "failed to resolve type for argument expr in CallExpr"); + return; + } + + auto resolved_argument_type = Resolver::TypeCheckBase::coercion_site ( + argument->get_mappings ().get_hirid (), fnparam.second, + argument_expr_tyty, argument->get_locus ()); + if (resolved_argument_type->get_kind () == TyTy::TypeKind::ERROR) + { + rust_error_at (argument->get_locus (), + "Type Resolution failure on parameter"); + return; + } + + i++; + } + + if (i != num_args_to_call) + { + rust_error_at (call.get_locus (), + "unexpected number of arguments %lu expected %lu", + (unsigned long) i, (unsigned long) call.num_params ()); + return; + } + + type.monomorphize (); + + resolved = type.get_return_type ()->monomorphized_clone (); +} + +} // namespace TyTy +} // namespace Rust diff --git a/gcc/rust/typecheck/rust-tyty-call.h b/gcc/rust/typecheck/rust-tyty-call.h new file mode 100644 index 00000000000..51817e6446d --- /dev/null +++ b/gcc/rust/typecheck/rust-tyty-call.h @@ -0,0 +1,147 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_TYTY_CALL +#define RUST_TYTY_CALL + +#include "rust-diagnostics.h" +#include "rust-hir-full.h" +#include "rust-tyty-visitor.h" +#include "rust-tyty.h" +#include "rust-hir-type-check.h" + +namespace Rust { +namespace TyTy { + +class TypeCheckCallExpr : private TyVisitor +{ +public: + static BaseType *go (BaseType *ref, HIR::CallExpr &call, + TyTy::VariantDef &variant, + Resolver::TypeCheckContext *context) + { + TypeCheckCallExpr checker (call, variant, context); + ref->accept_vis (checker); + return checker.resolved; + } + + void visit (InferType &) override { gcc_unreachable (); } + void visit (TupleType &) override { gcc_unreachable (); } + void visit (ArrayType &) override { gcc_unreachable (); } + void visit (SliceType &) override { gcc_unreachable (); } + void visit (BoolType &) override { gcc_unreachable (); } + void visit (IntType &) override { gcc_unreachable (); } + void visit (UintType &) override { gcc_unreachable (); } + void visit (FloatType &) override { gcc_unreachable (); } + void visit (USizeType &) override { gcc_unreachable (); } + void visit (ISizeType &) override { gcc_unreachable (); } + void visit (ErrorType &) override { gcc_unreachable (); } + void visit (CharType &) override { gcc_unreachable (); } + void visit (ReferenceType &) override { gcc_unreachable (); } + void visit (PointerType &) override { gcc_unreachable (); } + void visit (ParamType &) override { gcc_unreachable (); } + void visit (StrType &) override { gcc_unreachable (); } + void visit (NeverType &) override { gcc_unreachable (); } + void visit (PlaceholderType &) override { gcc_unreachable (); } + void visit (ProjectionType &) override { gcc_unreachable (); } + void visit (DynamicObjectType &) override { gcc_unreachable (); } + void visit (ClosureType &type) override { gcc_unreachable (); } + + // tuple-structs + void visit (ADTType &type) override; + + // call fns + void visit (FnType &type) override; + void visit (FnPtr &type) override; + +private: + TypeCheckCallExpr (HIR::CallExpr &c, TyTy::VariantDef &variant, + Resolver::TypeCheckContext *context) + : resolved (new TyTy::ErrorType (c.get_mappings ().get_hirid ())), call (c), + variant (variant), context (context), + mappings (Analysis::Mappings::get ()) + {} + + BaseType *resolved; + HIR::CallExpr &call; + TyTy::VariantDef &variant; + Resolver::TypeCheckContext *context; + Analysis::Mappings *mappings; +}; + +class TypeCheckMethodCallExpr : private TyVisitor +{ +public: + // Resolve the Method parameters and return back the return type + static BaseType *go (BaseType *ref, HIR::MethodCallExpr &call, + TyTy::BaseType *adjusted_self, + Resolver::TypeCheckContext *context) + { + TypeCheckMethodCallExpr checker (call, adjusted_self, context); + ref->accept_vis (checker); + return checker.resolved; + } + + void visit (InferType &) override { gcc_unreachable (); } + void visit (TupleType &) override { gcc_unreachable (); } + void visit (ArrayType &) override { gcc_unreachable (); } + void visit (SliceType &) override { gcc_unreachable (); } + void visit (BoolType &) override { gcc_unreachable (); } + void visit (IntType &) override { gcc_unreachable (); } + void visit (UintType &) override { gcc_unreachable (); } + void visit (FloatType &) override { gcc_unreachable (); } + void visit (USizeType &) override { gcc_unreachable (); } + void visit (ISizeType &) override { gcc_unreachable (); } + void visit (ErrorType &) override { gcc_unreachable (); } + void visit (ADTType &) override { gcc_unreachable (); }; + void visit (CharType &) override { gcc_unreachable (); } + void visit (ReferenceType &) override { gcc_unreachable (); } + void visit (PointerType &) override { gcc_unreachable (); } + void visit (ParamType &) override { gcc_unreachable (); } + void visit (StrType &) override { gcc_unreachable (); } + void visit (NeverType &) override { gcc_unreachable (); } + void visit (PlaceholderType &) override { gcc_unreachable (); } + void visit (ProjectionType &) override { gcc_unreachable (); } + void visit (DynamicObjectType &) override { gcc_unreachable (); } + + // FIXME + void visit (FnPtr &type) override { gcc_unreachable (); } + + // call fns + void visit (FnType &type) override; + void visit (ClosureType &type) override { gcc_unreachable (); } + +private: + TypeCheckMethodCallExpr (HIR::MethodCallExpr &c, + TyTy::BaseType *adjusted_self, + Resolver::TypeCheckContext *context) + : resolved (nullptr), call (c), adjusted_self (adjusted_self), + context (context), mappings (Analysis::Mappings::get ()) + {} + + BaseType *resolved; + HIR::MethodCallExpr &call; + TyTy::BaseType *adjusted_self; + Resolver::TypeCheckContext *context; + Analysis::Mappings *mappings; +}; + +} // namespace TyTy +} // namespace Rust + +#endif // RUST_TYTY_CALL diff --git a/gcc/rust/typecheck/rust-tyty-cmp.h b/gcc/rust/typecheck/rust-tyty-cmp.h new file mode 100644 index 00000000000..07d1dea7464 --- /dev/null +++ b/gcc/rust/typecheck/rust-tyty-cmp.h @@ -0,0 +1,1554 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_TYTY_CMP_H +#define RUST_TYTY_CMP_H + +#include "rust-diagnostics.h" +#include "rust-tyty.h" +#include "rust-tyty-visitor.h" +#include "rust-hir-map.h" +#include "rust-hir-type-check.h" + +namespace Rust { +namespace TyTy { + +class BaseCmp : public TyConstVisitor +{ +public: + virtual bool can_eq (const BaseType *other) + { + if (other->get_kind () == TypeKind::PARAM) + { + const ParamType *p = static_cast (other); + other = p->resolve (); + } + if (other->get_kind () == TypeKind::PLACEHOLDER) + { + const PlaceholderType *p = static_cast (other); + if (p->can_resolve ()) + { + other = p->resolve (); + } + } + if (other->get_kind () == TypeKind::PROJECTION) + { + const ProjectionType *p = static_cast (other); + other = p->get (); + } + + other->accept_vis (*this); + return ok; + } + + virtual void visit (const TupleType &type) override + { + ok = false; + + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const ADTType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const InferType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const FnType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const FnPtr &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const ArrayType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const SliceType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const BoolType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const IntType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const UintType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const USizeType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const ISizeType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const FloatType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const ErrorType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const CharType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const ReferenceType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const PointerType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const StrType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const NeverType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const ProjectionType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const PlaceholderType &type) override + { + // it is ok for types to can eq to a placeholder + ok = true; + } + + virtual void visit (const ParamType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const DynamicObjectType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + + virtual void visit (const ClosureType &type) override + { + ok = false; + if (emit_error_flag) + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus + = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + } + +protected: + BaseCmp (const BaseType *base, bool emit_errors) + : mappings (Analysis::Mappings::get ()), + context (Resolver::TypeCheckContext::get ()), ok (false), + emit_error_flag (emit_errors) + {} + + Analysis::Mappings *mappings; + Resolver::TypeCheckContext *context; + + bool ok; + bool emit_error_flag; + +private: + /* Returns a pointer to the ty that created this rule. */ + virtual const BaseType *get_base () const = 0; +}; + +class InferCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + InferCmp (const InferType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const BoolType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + + void visit (const IntType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () + == TyTy::InferType::InferTypeKind::INTEGRAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + + void visit (const UintType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () + == TyTy::InferType::InferTypeKind::INTEGRAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + + void visit (const USizeType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () + == TyTy::InferType::InferTypeKind::INTEGRAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + + void visit (const ISizeType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () + == TyTy::InferType::InferTypeKind::INTEGRAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + + void visit (const FloatType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () == TyTy::InferType::InferTypeKind::FLOAT); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + + void visit (const ArrayType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + + void visit (const SliceType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + + void visit (const ADTType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + + void visit (const TupleType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + + void visit (const InferType &type) override + { + switch (base->get_infer_kind ()) + { + case InferType::InferTypeKind::GENERAL: + ok = true; + return; + + case InferType::InferTypeKind::INTEGRAL: { + if (type.get_infer_kind () == InferType::InferTypeKind::INTEGRAL) + { + ok = true; + return; + } + else if (type.get_infer_kind () == InferType::InferTypeKind::GENERAL) + { + ok = true; + return; + } + } + break; + + case InferType::InferTypeKind::FLOAT: { + if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT) + { + ok = true; + return; + } + else if (type.get_infer_kind () == InferType::InferTypeKind::GENERAL) + { + ok = true; + return; + } + } + break; + } + + BaseCmp::visit (type); + } + + void visit (const CharType &type) override + { + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + } + + void visit (const ReferenceType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + + void visit (const PointerType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + + void visit (const ParamType &) override { ok = true; } + + void visit (const DynamicObjectType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + + void visit (const ClosureType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + ok = true; + return; + } + + BaseCmp::visit (type); + } + +private: + const BaseType *get_base () const override { return base; } + const InferType *base; +}; + +class FnCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + FnCmp (const FnType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const InferType &type) override + { + ok = type.get_infer_kind () == InferType::InferTypeKind::GENERAL; + } + + void visit (const FnType &type) override + { + if (base->num_params () != type.num_params ()) + { + BaseCmp::visit (type); + return; + } + + for (size_t i = 0; i < base->num_params (); i++) + { + auto a = base->param_at (i).second; + auto b = type.param_at (i).second; + + if (!a->can_eq (b, emit_error_flag)) + { + emit_error_flag = false; + BaseCmp::visit (type); + return; + } + } + + if (!base->get_return_type ()->can_eq (type.get_return_type (), + emit_error_flag)) + { + emit_error_flag = false; + BaseCmp::visit (type); + return; + } + + ok = true; + } + +private: + const BaseType *get_base () const override { return base; } + const FnType *base; +}; + +class FnptrCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + FnptrCmp (const FnPtr *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + + void visit (const FnPtr &type) override + { + if (base->num_params () != type.num_params ()) + { + BaseCmp::visit (type); + return; + } + + auto this_ret_type = base->get_return_type (); + auto other_ret_type = type.get_return_type (); + if (!this_ret_type->can_eq (other_ret_type, emit_error_flag)) + { + BaseCmp::visit (type); + return; + } + + for (size_t i = 0; i < base->num_params (); i++) + { + auto this_param = base->param_at (i); + auto other_param = type.param_at (i); + if (!this_param->can_eq (other_param, emit_error_flag)) + { + BaseCmp::visit (type); + return; + } + } + + ok = true; + } + + void visit (const FnType &type) override + { + if (base->num_params () != type.num_params ()) + { + BaseCmp::visit (type); + return; + } + + auto this_ret_type = base->get_return_type (); + auto other_ret_type = type.get_return_type (); + if (!this_ret_type->can_eq (other_ret_type, emit_error_flag)) + { + BaseCmp::visit (type); + return; + } + + for (size_t i = 0; i < base->num_params (); i++) + { + auto this_param = base->param_at (i); + auto other_param = type.param_at (i).second; + if (!this_param->can_eq (other_param, emit_error_flag)) + { + BaseCmp::visit (type); + return; + } + } + + ok = true; + } + +private: + const BaseType *get_base () const override { return base; } + const FnPtr *base; +}; + +class ClosureCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + ClosureCmp (const ClosureType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + +private: + const BaseType *get_base () const override { return base; } + const ClosureType *base; +}; + +class ArrayCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + ArrayCmp (const ArrayType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const ArrayType &type) override + { + // check base type + const BaseType *base_element = base->get_element_type (); + const BaseType *other_element = type.get_element_type (); + if (!base_element->can_eq (other_element, emit_error_flag)) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + + void visit (const InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + +private: + const BaseType *get_base () const override { return base; } + const ArrayType *base; +}; + +class SliceCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + SliceCmp (const SliceType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const SliceType &type) override + { + // check base type + const BaseType *base_element = base->get_element_type (); + const BaseType *other_element = type.get_element_type (); + if (!base_element->can_eq (other_element, emit_error_flag)) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + + void visit (const InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + +private: + const BaseType *get_base () const override { return base; } + const SliceType *base; +}; + +class BoolCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + BoolCmp (const BoolType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const BoolType &type) override { ok = true; } + + void visit (const InferType &type) override + { + ok = type.get_infer_kind () == InferType::InferTypeKind::GENERAL; + } + +private: + const BaseType *get_base () const override { return base; } + const BoolType *base; +}; + +class IntCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + IntCmp (const IntType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const InferType &type) override + { + ok = type.get_infer_kind () != InferType::InferTypeKind::FLOAT; + } + + void visit (const IntType &type) override + { + ok = type.get_int_kind () == base->get_int_kind (); + } + +private: + const BaseType *get_base () const override { return base; } + const IntType *base; +}; + +class UintCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + UintCmp (const UintType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const InferType &type) override + { + ok = type.get_infer_kind () != InferType::InferTypeKind::FLOAT; + } + + void visit (const UintType &type) override + { + ok = type.get_uint_kind () == base->get_uint_kind (); + } + +private: + const BaseType *get_base () const override { return base; } + const UintType *base; +}; + +class FloatCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + FloatCmp (const FloatType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const InferType &type) override + { + ok = type.get_infer_kind () != InferType::InferTypeKind::INTEGRAL; + } + + void visit (const FloatType &type) override + { + ok = type.get_float_kind () == base->get_float_kind (); + } + +private: + const BaseType *get_base () const override { return base; } + const FloatType *base; +}; + +class ADTCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + ADTCmp (const ADTType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const ADTType &type) override + { + if (base->get_adt_kind () != type.get_adt_kind ()) + { + BaseCmp::visit (type); + return; + } + + if (base->get_identifier ().compare (type.get_identifier ()) != 0) + { + BaseCmp::visit (type); + return; + } + + if (base->number_of_variants () != type.number_of_variants ()) + { + BaseCmp::visit (type); + return; + } + + for (size_t i = 0; i < type.number_of_variants (); ++i) + { + TyTy::VariantDef *a = base->get_variants ().at (i); + TyTy::VariantDef *b = type.get_variants ().at (i); + + if (a->num_fields () != b->num_fields ()) + { + BaseCmp::visit (type); + return; + } + + for (size_t j = 0; j < a->num_fields (); j++) + { + TyTy::StructFieldType *base_field = a->get_field_at_index (j); + TyTy::StructFieldType *other_field = b->get_field_at_index (j); + + TyTy::BaseType *this_field_ty = base_field->get_field_type (); + TyTy::BaseType *other_field_ty = other_field->get_field_type (); + + if (!this_field_ty->can_eq (other_field_ty, emit_error_flag)) + { + BaseCmp::visit (type); + return; + } + } + } + + ok = true; + } + + void visit (const InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + +private: + const BaseType *get_base () const override { return base; } + const ADTType *base; +}; + +class TupleCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + TupleCmp (const TupleType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const TupleType &type) override + { + if (base->num_fields () != type.num_fields ()) + { + BaseCmp::visit (type); + return; + } + + for (size_t i = 0; i < base->num_fields (); i++) + { + BaseType *bo = base->get_field (i); + BaseType *fo = type.get_field (i); + + if (!bo->can_eq (fo, emit_error_flag)) + { + BaseCmp::visit (type); + return; + } + } + + ok = true; + } + + void visit (const InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + +private: + const BaseType *get_base () const override { return base; } + const TupleType *base; +}; + +class USizeCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + USizeCmp (const USizeType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const InferType &type) override + { + ok = type.get_infer_kind () != InferType::InferTypeKind::FLOAT; + } + + void visit (const USizeType &type) override { ok = true; } + +private: + const BaseType *get_base () const override { return base; } + const USizeType *base; +}; + +class ISizeCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + ISizeCmp (const ISizeType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const InferType &type) override + { + ok = type.get_infer_kind () != InferType::InferTypeKind::FLOAT; + } + + void visit (const ISizeType &type) override { ok = true; } + +private: + const BaseType *get_base () const override { return base; } + const ISizeType *base; +}; + +class CharCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + CharCmp (const CharType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const InferType &type) override + { + ok = type.get_infer_kind () == InferType::InferTypeKind::GENERAL; + } + + void visit (const CharType &type) override { ok = true; } + +private: + const BaseType *get_base () const override { return base; } + const CharType *base; +}; + +class ReferenceCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + ReferenceCmp (const ReferenceType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const ReferenceType &type) override + { + auto base_type = base->get_base (); + auto other_base_type = type.get_base (); + + bool mutability_match = base->is_mutable () == type.is_mutable (); + if (!mutability_match) + { + BaseCmp::visit (type); + return; + } + + if (!base_type->can_eq (other_base_type, emit_error_flag)) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + + void visit (const InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + +private: + const BaseType *get_base () const override { return base; } + const ReferenceType *base; +}; + +class PointerCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + PointerCmp (const PointerType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const PointerType &type) override + { + auto base_type = base->get_base (); + auto other_base_type = type.get_base (); + + // rust is permissive about mutablity here you can always go from mutable to + // immutable but not the otherway round + bool mutability_ok = base->is_mutable () ? type.is_mutable () : true; + if (!mutability_ok) + { + BaseCmp::visit (type); + return; + } + + if (!base_type->can_eq (other_base_type, emit_error_flag)) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + + void visit (const InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + +private: + const BaseType *get_base () const override { return base; } + const PointerType *base; +}; + +class ParamCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + ParamCmp (const ParamType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + // param types are a placeholder we shouldn't have cases where we unify + // against it. eg: struct foo { a: T }; When we invoke it we can do either: + // + // foo{ a: 123 }. + // Then this enforces the i32 type to be referenced on the + // field via an hirid. + // + // rust also allows for a = foo{a:123}; Where we can use an Inference Variable + // to handle the typing of the struct + bool can_eq (const BaseType *other) override + { + if (!base->can_resolve ()) + return BaseCmp::can_eq (other); + + auto lookup = base->resolve (); + return lookup->can_eq (other, emit_error_flag); + } + + // imagine the case where we have: + // struct Foo(T); + // Then we declare a generic impl block + // impl Foo { ... } + // both of these types are compatible so we mostly care about the number of + // generic arguments + void visit (const ParamType &) override { ok = true; } + + void visit (const InferType &) override { ok = true; } + + void visit (const FnType &) override { ok = true; } + + void visit (const FnPtr &) override { ok = true; } + + void visit (const ADTType &) override { ok = true; } + + void visit (const ArrayType &) override { ok = true; } + + void visit (const SliceType &) override { ok = true; } + + void visit (const BoolType &) override { ok = true; } + + void visit (const IntType &) override { ok = true; } + + void visit (const UintType &) override { ok = true; } + + void visit (const USizeType &) override { ok = true; } + + void visit (const ISizeType &) override { ok = true; } + + void visit (const FloatType &) override { ok = true; } + + void visit (const CharType &) override { ok = true; } + + void visit (const ReferenceType &) override { ok = true; } + + void visit (const PointerType &) override { ok = true; } + + void visit (const StrType &) override { ok = true; } + + void visit (const NeverType &) override { ok = true; } + + void visit (const DynamicObjectType &) override { ok = true; } + + void visit (const PlaceholderType &type) override + { + ok = base->get_symbol ().compare (type.get_symbol ()) == 0; + } + +private: + const BaseType *get_base () const override { return base; } + const ParamType *base; +}; + +class StrCmp : public BaseCmp +{ + // FIXME we will need a enum for the StrType like ByteBuf etc.. + using Rust::TyTy::BaseCmp::visit; + +public: + StrCmp (const StrType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const StrType &type) override { ok = true; } + + void visit (const InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + +private: + const BaseType *get_base () const override { return base; } + const StrType *base; +}; + +class NeverCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + NeverCmp (const NeverType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const NeverType &type) override { ok = true; } + + void visit (const InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseCmp::visit (type); + return; + } + + ok = true; + } + +private: + const BaseType *get_base () const override { return base; } + const NeverType *base; +}; + +class PlaceholderCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + PlaceholderCmp (const PlaceholderType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + bool can_eq (const BaseType *other) override + { + if (!base->can_resolve ()) + return BaseCmp::can_eq (other); + + BaseType *lookup = base->resolve (); + return lookup->can_eq (other, emit_error_flag); + } + + void visit (const TupleType &) override { ok = true; } + + void visit (const ADTType &) override { ok = true; } + + void visit (const InferType &) override { ok = true; } + + void visit (const FnType &) override { ok = true; } + + void visit (const FnPtr &) override { ok = true; } + + void visit (const ArrayType &) override { ok = true; } + + void visit (const BoolType &) override { ok = true; } + + void visit (const IntType &) override { ok = true; } + + void visit (const UintType &) override { ok = true; } + + void visit (const USizeType &) override { ok = true; } + + void visit (const ISizeType &) override { ok = true; } + + void visit (const FloatType &) override { ok = true; } + + void visit (const ErrorType &) override { ok = true; } + + void visit (const CharType &) override { ok = true; } + + void visit (const ReferenceType &) override { ok = true; } + + void visit (const ParamType &) override { ok = true; } + + void visit (const StrType &) override { ok = true; } + + void visit (const NeverType &) override { ok = true; } + + void visit (const SliceType &) override { ok = true; } + +private: + const BaseType *get_base () const override { return base; } + + const PlaceholderType *base; +}; + +class DynamicCmp : public BaseCmp +{ + using Rust::TyTy::BaseCmp::visit; + +public: + DynamicCmp (const DynamicObjectType *base, bool emit_errors) + : BaseCmp (base, emit_errors), base (base) + {} + + void visit (const DynamicObjectType &type) override + { + if (base->num_specified_bounds () != type.num_specified_bounds ()) + { + BaseCmp::visit (type); + return; + } + + Location ref_locus = mappings->lookup_location (type.get_ref ()); + ok = base->bounds_compatible (type, ref_locus, false); + } + +private: + const BaseType *get_base () const override { return base; } + + const DynamicObjectType *base; +}; + +} // namespace TyTy +} // namespace Rust + +#endif // RUST_TYTY_CMP_H diff --git a/gcc/rust/typecheck/rust-tyty-rules.h b/gcc/rust/typecheck/rust-tyty-rules.h new file mode 100644 index 00000000000..77d912a5921 --- /dev/null +++ b/gcc/rust/typecheck/rust-tyty-rules.h @@ -0,0 +1,1584 @@ +// Copyright (C) 2020-2022 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_TYTY_RULES +#define RUST_TYTY_RULES + +#include "rust-diagnostics.h" +#include "rust-tyty.h" +#include "rust-tyty-visitor.h" +#include "rust-hir-map.h" +#include "rust-hir-type-check.h" + +namespace Rust { +namespace TyTy { + +/* Rules specify how to unify two Ty. For example, the result of unifying the + two tuples (u64, A) and (B, i64) would be (u64, i64). + + Performing a unification requires a double dispatch. To illustrate, suppose + we want to unify `ty1` and `ty2`. Here's what it looks like: + 1. The caller calls `ty1.unify(ty2)`. This is the first dispatch. + 2. `ty1` creates a rule specific to its type(e.g. TupleRules). + 3. The rule calls `ty2.accept_vis(rule)`. This is the second dispatch. + 4. `ty2` calls `rule.visit(*this)`, which will method-overload to the + correct implementation at compile time. + + The nice thing about Rules is that they seperate unification logic from the + representation of Ty. To support unifying a new Ty, implement its + `accept_vis` and `unify` method to pass the unification request to Rules. + Then, create a new `XXXRules` class and implement one `visit` method for + every Ty it can unify with. */ +class BaseRules : public TyVisitor +{ +public: + virtual ~BaseRules () {} + + /* Unify two ty. Returns a pointer to the newly-created unified ty, or nullptr + if the two types cannot be unified. The caller is responsible for releasing + the memory of the returned ty. + + This method is meant to be used internally by Ty. If you're trying to unify + two ty, you can simply call `unify` on ty themselves. */ + virtual BaseType *unify (BaseType *other) + { + if (other->get_kind () == TypeKind::PARAM) + { + ParamType *p = static_cast (other); + other = p->resolve (); + } + else if (other->get_kind () == TypeKind::PLACEHOLDER) + { + PlaceholderType *p = static_cast (other); + if (p->can_resolve ()) + { + other = p->resolve (); + return get_base ()->unify (other); + } + } + else if (other->get_kind () == TypeKind::PROJECTION) + { + ProjectionType *p = static_cast (other); + other = p->get (); + return get_base ()->unify (other); + } + + other->accept_vis (*this); + if (resolved->get_kind () == TyTy::TypeKind::ERROR) + return resolved; + + resolved->append_reference (get_base ()->get_ref ()); + resolved->append_reference (other->get_ref ()); + for (auto ref : get_base ()->get_combined_refs ()) + resolved->append_reference (ref); + for (auto ref : other->get_combined_refs ()) + resolved->append_reference (ref); + + other->append_reference (resolved->get_ref ()); + other->append_reference (get_base ()->get_ref ()); + get_base ()->append_reference (resolved->get_ref ()); + get_base ()->append_reference (other->get_ref ()); + + bool result_resolved = resolved->get_kind () != TyTy::TypeKind::INFER; + bool result_is_infer_var = resolved->get_kind () == TyTy::TypeKind::INFER; + bool results_is_non_general_infer_var + = (result_is_infer_var + && (static_cast (resolved))->get_infer_kind () + != TyTy::InferType::GENERAL); + if (result_resolved || results_is_non_general_infer_var) + { + for (auto &ref : resolved->get_combined_refs ()) + { + TyTy::BaseType *ref_tyty = nullptr; + bool ok = context->lookup_type (ref, &ref_tyty); + if (!ok) + continue; + + // if any of the types are inference variables lets fix them + if (ref_tyty->get_kind () == TyTy::TypeKind::INFER) + { + context->insert_type ( + Analysis::NodeMapping (mappings->get_current_crate (), + UNKNOWN_NODEID, ref, + UNKNOWN_LOCAL_DEFID), + resolved->clone ()); + } + } + } + return resolved; + } + + virtual void visit (TupleType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ADTType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (InferType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (FnType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (FnPtr &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ArrayType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (SliceType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (BoolType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (IntType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (UintType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (USizeType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ISizeType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (FloatType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ErrorType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (CharType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ReferenceType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (PointerType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ParamType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (StrType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (NeverType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (PlaceholderType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ProjectionType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (DynamicObjectType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + + virtual void visit (ClosureType &type) override + { + Location ref_locus = mappings->lookup_location (type.get_ref ()); + Location base_locus = mappings->lookup_location (get_base ()->get_ref ()); + RichLocation r (ref_locus); + r.add_range (base_locus); + rust_error_at (r, "expected [%s] got [%s]", + get_base ()->as_string ().c_str (), + type.as_string ().c_str ()); + } + +protected: + BaseRules (BaseType *base) + : mappings (Analysis::Mappings::get ()), + context (Resolver::TypeCheckContext::get ()), + resolved (new ErrorType (base->get_ref (), base->get_ref ())) + {} + + Analysis::Mappings *mappings; + Resolver::TypeCheckContext *context; + + /* Temporary storage for the result of a unification. + We could return the result directly instead of storing it in the rule + object, but that involves modifying the visitor pattern to accommodate + the return value, which is too complex. */ + BaseType *resolved; + +private: + /* Returns a pointer to the ty that created this rule. */ + virtual BaseType *get_base () = 0; +}; + +class InferRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + InferRules (InferType *base) : BaseRules (base), base (base) {} + + void visit (BoolType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (IntType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () + == TyTy::InferType::InferTypeKind::INTEGRAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (UintType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () + == TyTy::InferType::InferTypeKind::INTEGRAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (USizeType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () + == TyTy::InferType::InferTypeKind::INTEGRAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (ISizeType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () + == TyTy::InferType::InferTypeKind::INTEGRAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (FloatType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL) + || (base->get_infer_kind () == TyTy::InferType::InferTypeKind::FLOAT); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (ArrayType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (SliceType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (ADTType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (TupleType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (InferType &type) override + { + switch (base->get_infer_kind ()) + { + case InferType::InferTypeKind::GENERAL: + resolved = type.clone (); + return; + + case InferType::InferTypeKind::INTEGRAL: { + if (type.get_infer_kind () == InferType::InferTypeKind::INTEGRAL) + { + resolved = type.clone (); + return; + } + else if (type.get_infer_kind () == InferType::InferTypeKind::GENERAL) + { + resolved = base->clone (); + return; + } + } + break; + + case InferType::InferTypeKind::FLOAT: { + if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT) + { + resolved = type.clone (); + return; + } + else if (type.get_infer_kind () == InferType::InferTypeKind::GENERAL) + { + resolved = base->clone (); + return; + } + } + break; + } + + BaseRules::visit (type); + } + + void visit (CharType &type) override + { + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + } + + void visit (ReferenceType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (PointerType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (ParamType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (DynamicObjectType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + + void visit (ClosureType &type) override + { + bool is_valid + = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL); + if (is_valid) + { + resolved = type.clone (); + return; + } + + BaseRules::visit (type); + } + +private: + BaseType *get_base () override { return base; } + + InferType *base; +}; + +class FnRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + FnRules (FnType *base) : BaseRules (base), base (base) {} + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (FnType &type) override + { + if (base->num_params () != type.num_params ()) + { + BaseRules::visit (type); + return; + } + + for (size_t i = 0; i < base->num_params (); i++) + { + auto a = base->param_at (i).second; + auto b = type.param_at (i).second; + + auto unified_param = a->unify (b); + if (unified_param == nullptr) + { + BaseRules::visit (type); + return; + } + } + + auto unified_return + = base->get_return_type ()->unify (type.get_return_type ()); + if (unified_return == nullptr) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + +private: + BaseType *get_base () override { return base; } + + FnType *base; +}; + +class FnptrRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + FnptrRules (FnPtr *base) : BaseRules (base), base (base) {} + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (FnPtr &type) override + { + auto this_ret_type = base->get_return_type (); + auto other_ret_type = type.get_return_type (); + auto unified_result = this_ret_type->unify (other_ret_type); + if (unified_result == nullptr + || unified_result->get_kind () == TypeKind::ERROR) + { + BaseRules::visit (type); + return; + } + + if (base->num_params () != type.num_params ()) + { + BaseRules::visit (type); + return; + } + + for (size_t i = 0; i < base->num_params (); i++) + { + auto this_param = base->param_at (i); + auto other_param = type.param_at (i); + auto unified_param = this_param->unify (other_param); + if (unified_param == nullptr + || unified_param->get_kind () == TypeKind::ERROR) + { + BaseRules::visit (type); + return; + } + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (FnType &type) override + { + auto this_ret_type = base->get_return_type (); + auto other_ret_type = type.get_return_type (); + auto unified_result = this_ret_type->unify (other_ret_type); + if (unified_result == nullptr + || unified_result->get_kind () == TypeKind::ERROR) + { + BaseRules::visit (type); + return; + } + + if (base->num_params () != type.num_params ()) + { + BaseRules::visit (type); + return; + } + + for (size_t i = 0; i < base->num_params (); i++) + { + auto this_param = base->param_at (i); + auto other_param = type.param_at (i).second; + auto unified_param = this_param->unify (other_param); + if (unified_param == nullptr + || unified_param->get_kind () == TypeKind::ERROR) + { + BaseRules::visit (type); + return; + } + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + +private: + BaseType *get_base () override { return base; } + + FnPtr *base; +}; + +class ClosureRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + ClosureRules (ClosureType *base) : BaseRules (base), base (base) {} + + // TODO + +private: + BaseType *get_base () override { return base; } + + ClosureType *base; +}; + +class ArrayRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + ArrayRules (ArrayType *base) : BaseRules (base), base (base) {} + + void visit (ArrayType &type) override + { + // check base type + auto base_resolved + = base->get_element_type ()->unify (type.get_element_type ()); + if (base_resolved == nullptr) + { + BaseRules::visit (type); + return; + } + + resolved + = new ArrayType (type.get_ref (), type.get_ty_ref (), + type.get_ident ().locus, type.get_capacity_expr (), + TyVar (base_resolved->get_ref ())); + } + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + +private: + BaseType *get_base () override { return base; } + + ArrayType *base; +}; + +class SliceRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + SliceRules (SliceType *base) : BaseRules (base), base (base) {} + + void visit (SliceType &type) override + { + // check base type + auto base_resolved + = base->get_element_type ()->unify (type.get_element_type ()); + if (base_resolved == nullptr) + { + BaseRules::visit (type); + return; + } + + resolved = new SliceType (type.get_ref (), type.get_ty_ref (), + type.get_ident ().locus, + TyVar (base_resolved->get_ref ())); + } + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + +private: + BaseType *get_base () override { return base; } + + SliceType *base; +}; + +class BoolRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + BoolRules (BoolType *base) : BaseRules (base), base (base) {} + + void visit (BoolType &type) override + { + resolved = new BoolType (type.get_ref (), type.get_ty_ref ()); + } + + void visit (InferType &type) override + { + switch (type.get_infer_kind ()) + { + case InferType::InferTypeKind::GENERAL: + resolved = base->clone (); + break; + + default: + BaseRules::visit (type); + break; + } + } + +private: + BaseType *get_base () override { return base; } + + BoolType *base; +}; + +class IntRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + IntRules (IntType *base) : BaseRules (base), base (base) {} + + void visit (InferType &type) override + { + // cant assign a float inference variable + if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (IntType &type) override + { + if (type.get_int_kind () != base->get_int_kind ()) + { + BaseRules::visit (type); + return; + } + + resolved + = new IntType (type.get_ref (), type.get_ty_ref (), type.get_int_kind ()); + } + +private: + BaseType *get_base () override { return base; } + + IntType *base; +}; + +class UintRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + UintRules (UintType *base) : BaseRules (base), base (base) {} + + void visit (InferType &type) override + { + // cant assign a float inference variable + if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (UintType &type) override + { + if (type.get_uint_kind () != base->get_uint_kind ()) + { + BaseRules::visit (type); + return; + } + + resolved = new UintType (type.get_ref (), type.get_ty_ref (), + type.get_uint_kind ()); + } + +private: + BaseType *get_base () override { return base; } + + UintType *base; +}; + +class FloatRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + FloatRules (FloatType *base) : BaseRules (base), base (base) {} + + void visit (InferType &type) override + { + if (type.get_infer_kind () == InferType::InferTypeKind::INTEGRAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (FloatType &type) override + { + if (type.get_float_kind () != base->get_float_kind ()) + { + BaseRules::visit (type); + return; + } + + resolved = new FloatType (type.get_ref (), type.get_ty_ref (), + type.get_float_kind ()); + } + +private: + BaseType *get_base () override { return base; } + + FloatType *base; +}; + +class ADTRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + ADTRules (ADTType *base) : BaseRules (base), base (base) {} + + void visit (ADTType &type) override + { + if (base->get_adt_kind () != type.get_adt_kind ()) + { + BaseRules::visit (type); + return; + } + + if (base->get_identifier ().compare (type.get_identifier ()) != 0) + { + BaseRules::visit (type); + return; + } + + if (base->number_of_variants () != type.number_of_variants ()) + { + BaseRules::visit (type); + return; + } + + for (size_t i = 0; i < type.number_of_variants (); ++i) + { + TyTy::VariantDef *a = base->get_variants ().at (i); + TyTy::VariantDef *b = type.get_variants ().at (i); + + if (a->num_fields () != b->num_fields ()) + { + BaseRules::visit (type); + return; + } + + for (size_t j = 0; j < a->num_fields (); j++) + { + TyTy::StructFieldType *base_field = a->get_field_at_index (j); + TyTy::StructFieldType *other_field = b->get_field_at_index (j); + + TyTy::BaseType *this_field_ty = base_field->get_field_type (); + TyTy::BaseType *other_field_ty = other_field->get_field_type (); + + BaseType *unified_ty = this_field_ty->unify (other_field_ty); + if (unified_ty->get_kind () == TyTy::TypeKind::ERROR) + return; + } + } + + // generic args for the unit-struct case + if (type.is_unit () && base->is_unit ()) + { + rust_assert (type.get_num_substitutions () + == base->get_num_substitutions ()); + + for (size_t i = 0; i < type.get_num_substitutions (); i++) + { + auto &a = base->get_substs ().at (i); + auto &b = type.get_substs ().at (i); + + auto pa = a.get_param_ty (); + auto pb = b.get_param_ty (); + + auto res = pa->unify (pb); + if (res->get_kind () == TyTy::TypeKind::ERROR) + { + return; + } + } + } + + resolved = type.clone (); + } + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + +private: + BaseType *get_base () override { return base; } + + ADTType *base; +}; + +class TupleRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + TupleRules (TupleType *base) : BaseRules (base), base (base) {} + + void visit (TupleType &type) override + { + if (base->num_fields () != type.num_fields ()) + { + BaseRules::visit (type); + return; + } + + std::vector fields; + for (size_t i = 0; i < base->num_fields (); i++) + { + BaseType *bo = base->get_field (i); + BaseType *fo = type.get_field (i); + + BaseType *unified_ty = bo->unify (fo); + if (unified_ty->get_kind () == TyTy::TypeKind::ERROR) + return; + + fields.push_back (TyVar (unified_ty->get_ref ())); + } + + resolved = new TyTy::TupleType (type.get_ref (), type.get_ty_ref (), + type.get_ident ().locus, fields); + } + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + +private: + BaseType *get_base () override { return base; } + + TupleType *base; +}; + +class USizeRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + USizeRules (USizeType *base) : BaseRules (base), base (base) {} + + void visit (InferType &type) override + { + // cant assign a float inference variable + if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (USizeType &type) override { resolved = type.clone (); } + +private: + BaseType *get_base () override { return base; } + + USizeType *base; +}; + +class ISizeRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + ISizeRules (ISizeType *base) : BaseRules (base), base (base) {} + + void visit (InferType &type) override + { + // cant assign a float inference variable + if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (ISizeType &type) override { resolved = type.clone (); } + +private: + BaseType *get_base () override { return base; } + + ISizeType *base; +}; + +class CharRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + CharRules (CharType *base) : BaseRules (base), base (base) {} + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + + void visit (CharType &type) override { resolved = type.clone (); } + +private: + BaseType *get_base () override { return base; } + + CharType *base; +}; + +class ReferenceRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + ReferenceRules (ReferenceType *base) : BaseRules (base), base (base) {} + + void visit (ReferenceType &type) override + { + auto base_type = base->get_base (); + auto other_base_type = type.get_base (); + + TyTy::BaseType *base_resolved = base_type->unify (other_base_type); + if (base_resolved == nullptr + || base_resolved->get_kind () == TypeKind::ERROR) + { + BaseRules::visit (type); + return; + } + + // rust is permissive about mutablity here you can always go from mutable to + // immutable but not the otherway round + bool mutability_ok = base->is_mutable () ? type.is_mutable () : true; + if (!mutability_ok) + { + BaseRules::visit (type); + return; + } + + resolved = new ReferenceType (base->get_ref (), base->get_ty_ref (), + TyVar (base_resolved->get_ref ()), + base->mutability ()); + } + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + +private: + BaseType *get_base () override { return base; } + + ReferenceType *base; +}; + +class PointerRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + PointerRules (PointerType *base) : BaseRules (base), base (base) {} + + void visit (PointerType &type) override + { + auto base_type = base->get_base (); + auto other_base_type = type.get_base (); + + TyTy::BaseType *base_resolved = base_type->unify (other_base_type); + if (base_resolved == nullptr + || base_resolved->get_kind () == TypeKind::ERROR) + { + BaseRules::visit (type); + return; + } + + // rust is permissive about mutablity here you can always go from mutable to + // immutable but not the otherway round + bool mutability_ok = base->is_mutable () ? type.is_mutable () : true; + if (!mutability_ok) + { + BaseRules::visit (type); + return; + } + + resolved = new PointerType (base->get_ref (), base->get_ty_ref (), + TyVar (base_resolved->get_ref ()), + base->mutability ()); + } + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + resolved->set_ref (type.get_ref ()); + } + +private: + BaseType *get_base () override { return base; } + + PointerType *base; +}; + +class ParamRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + ParamRules (ParamType *base) : BaseRules (base), base (base) {} + + // param types are a placeholder we shouldn't have cases where we unify + // against it. eg: struct foo { a: T }; When we invoke it we can do either: + // + // foo{ a: 123 }. + // Then this enforces the i32 type to be referenced on the + // field via an hirid. + // + // rust also allows for a = foo{a:123}; Where we can use an Inference Variable + // to handle the typing of the struct + BaseType *unify (BaseType *other) override final + { + if (!base->can_resolve ()) + return BaseRules::unify (other); + + auto lookup = base->resolve (); + return lookup->unify (other); + } + + void visit (ParamType &type) override + { + if (base->get_symbol ().compare (type.get_symbol ()) != 0) + { + BaseRules::visit (type); + return; + } + + resolved = type.clone (); + } + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + } + +private: + BaseType *get_base () override { return base; } + + ParamType *base; +}; + +class StrRules : public BaseRules +{ + // FIXME we will need a enum for the StrType like ByteBuf etc.. + using Rust::TyTy::BaseRules::visit; + +public: + StrRules (StrType *base) : BaseRules (base), base (base) {} + + void visit (StrType &type) override { resolved = type.clone (); } + +private: + BaseType *get_base () override { return base; } + + StrType *base; +}; + +class NeverRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + NeverRules (NeverType *base) : BaseRules (base), base (base) {} + + void visit (NeverType &type) override { resolved = type.clone (); } + +private: + BaseType *get_base () override { return base; } + + NeverType *base; +}; + +class PlaceholderRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + PlaceholderRules (PlaceholderType *base) : BaseRules (base), base (base) {} + + BaseType *unify (BaseType *other) override final + { + if (!base->can_resolve ()) + return BaseRules::unify (other); + + BaseType *lookup = base->resolve (); + return lookup->unify (other); + } + + void visit (PlaceholderType &type) override + { + if (base->get_symbol ().compare (type.get_symbol ()) != 0) + { + BaseRules::visit (type); + return; + } + + resolved = type.clone (); + } + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + } + +private: + BaseType *get_base () override { return base; } + + PlaceholderType *base; +}; + +class DynamicRules : public BaseRules +{ + using Rust::TyTy::BaseRules::visit; + +public: + DynamicRules (DynamicObjectType *base) : BaseRules (base), base (base) {} + + void visit (InferType &type) override + { + if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + } + + void visit (DynamicObjectType &type) override + { + if (base->num_specified_bounds () != type.num_specified_bounds ()) + { + BaseRules::visit (type); + return; + } + + Location ref_locus = mappings->lookup_location (type.get_ref ()); + if (!base->bounds_compatible (type, ref_locus, true)) + { + BaseRules::visit (type); + return; + } + + resolved = base->clone (); + } + +private: + BaseType *get_base () override { return base; } + + DynamicObjectType *base; +}; + +} // namespace TyTy +} // namespace Rust + +#endif // RUST_TYTY_RULES