Skip to content
32 changes: 32 additions & 0 deletions crates/hir-def/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,38 @@ impl ExprOrPatId {
}
stdx::impl_from!(ExprId, PatId for ExprOrPatId);

// FIXME: Eventually encode this as a single u32 like ExprOrPatId?
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, salsa::Update)]
pub enum TypeRefOrExprId {
TypeRefId(TypeRefId),
ExprId(ExprId),
}

impl TypeRefOrExprId {
pub fn as_type_ref(self) -> Option<TypeRefId> {
match self {
Self::TypeRefId(v) => Some(v),
_ => None,
}
}

pub fn is_type_ref(&self) -> bool {
matches!(self, Self::TypeRefId(_))
}

pub fn as_expr(self) -> Option<ExprId> {
match self {
Self::ExprId(v) => Some(v),
_ => None,
}
}

pub fn is_expr(&self) -> bool {
matches!(self, Self::ExprId(_))
}
}
stdx::impl_from!(TypeRefId, ExprId for TypeRefOrExprId);

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Label {
pub name: Name,
Expand Down
27 changes: 14 additions & 13 deletions crates/hir-def/src/hir/type_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
ExpressionStore,
path::{GenericArg, Path},
},
hir::ExprId,
hir::{ExprId, TypeRefOrExprId},
};

#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
Expand Down Expand Up @@ -192,24 +192,22 @@ impl TypeRef {
TypeRef::Tuple(ThinVec::new())
}

pub fn walk(this: TypeRefId, map: &ExpressionStore, f: &mut impl FnMut(TypeRefId, &TypeRef)) {
pub fn walk(this: TypeRefId, map: &ExpressionStore, f: &mut impl FnMut(TypeRefOrExprId)) {
go(this, f, map);

fn go(
type_ref_id: TypeRefId,
f: &mut impl FnMut(TypeRefId, &TypeRef),
map: &ExpressionStore,
) {
let type_ref = &map[type_ref_id];
f(type_ref_id, type_ref);
match type_ref {
fn go(type_ref_id: TypeRefId, f: &mut impl FnMut(TypeRefOrExprId), map: &ExpressionStore) {
f(type_ref_id.into());
match &map[type_ref_id] {
TypeRef::Fn(fn_) => {
fn_.params.iter().for_each(|&(_, param_type)| go(param_type, f, map))
}
TypeRef::Tuple(types) => types.iter().for_each(|&t| go(t, f, map)),
TypeRef::RawPtr(type_ref, _) | TypeRef::Slice(type_ref) => go(*type_ref, f, map),
TypeRef::Reference(it) => go(it.ty, f, map),
TypeRef::Array(it) => go(it.ty, f, map),
TypeRef::Array(it) => {
go(it.ty, f, map);
f(it.len.expr.into());
}
TypeRef::ImplTrait(bounds) | TypeRef::DynTrait(bounds) => {
for bound in bounds {
match bound {
Expand All @@ -225,7 +223,7 @@ impl TypeRef {
};
}

fn go_path(path: &Path, f: &mut impl FnMut(TypeRefId, &TypeRef), map: &ExpressionStore) {
fn go_path(path: &Path, f: &mut impl FnMut(TypeRefOrExprId), map: &ExpressionStore) {
if let Some(type_ref) = path.type_anchor() {
go(type_ref, f, map);
}
Expand All @@ -236,7 +234,10 @@ impl TypeRef {
GenericArg::Type(type_ref) => {
go(*type_ref, f, map);
}
GenericArg::Const(_) | GenericArg::Lifetime(_) => {}
GenericArg::Const(const_ref) => {
f(const_ref.expr.into());
}
GenericArg::Lifetime(_) => {}
}
}
for binding in args_and_bindings.bindings.iter() {
Expand Down
72 changes: 59 additions & 13 deletions crates/hir-ty/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use hir_def::{
TupleFieldId, TupleId, TypeOrConstParamId, VariantId,
attrs::AttrFlags,
expr_store::{Body, ExpressionStore, HygieneId, RootExprOrigin, path::Path},
hir::{BindingId, ExprId, ExprOrPatId, LabelId, PatId},
hir::{BindingId, Expr, ExprId, ExprOrPatId, LabelId, PatId, TypeRefOrExprId},
lang_item::LangItems,
layout::Integer,
resolver::{HasResolver, ResolveValueResult, Resolver, TypeNs, ValueNs},
Expand All @@ -54,7 +54,7 @@ use macros::{TypeFoldable, TypeVisitable};
use rustc_ast_ir::Mutability;
use rustc_hash::{FxHashMap, FxHashSet};
use rustc_type_ir::{
AliasTyKind, TypeFoldable,
AliasTyKind, InferTy, TypeFoldable,
inherent::{Const as _, IntoKind, Ty as _},
};
use smallvec::SmallVec;
Expand All @@ -65,7 +65,7 @@ use thin_vec::ThinVec;
use crate::{
ImplTraitId, IncorrectGenericsLenKind, PathLoweringDiagnostic, Span, TargetFeatures,
closure_analysis::PlaceBase,
collect_type_inference_vars,
collect_inference_vars,
db::{HirDatabase, InternedOpaqueTyId},
infer::{
callee::DeferredCallResolution,
Expand All @@ -84,10 +84,10 @@ use crate::{
},
method_resolution::CandidateId,
next_solver::{
AliasTy, Const, DbInterner, ErrorGuaranteed, GenericArgs, Region, StoredGenericArg,
StoredGenericArgs, StoredTy, StoredTys, Term, Ty, TyKind, Tys,
AliasTy, Const, DbInterner, ErrorGuaranteed, GenericArgs, Region, StoredConst,
StoredGenericArg, StoredGenericArgs, StoredTy, StoredTys, Term, Ty, TyKind, Tys,
abi::Safety,
infer::{InferCtxt, ObligationInspector, traits::ObligationCause},
infer::{InferCtxt, ObligationInspector, TyOrConstInferVar, traits::ObligationCause},
},
utils::TargetFeatureIsSafeInTarget,
};
Expand Down Expand Up @@ -679,6 +679,7 @@ pub struct InferenceResult {
pub(crate) type_of_pat: ArenaMap<PatId, StoredTy>,
pub(crate) type_of_binding: ArenaMap<BindingId, StoredTy>,
pub(crate) type_of_type_placeholder: FxHashMap<TypeRefId, StoredTy>,
pub(crate) const_of_const_placeholder: FxHashMap<TypeRefOrExprId, StoredConst>,
pub(crate) type_of_opaque: FxHashMap<InternedOpaqueTyId, StoredTy>,

pub(crate) type_mismatches: Option<Box<FxHashMap<ExprOrPatId, TypeMismatch>>>,
Expand Down Expand Up @@ -986,6 +987,7 @@ impl InferenceResult {
type_of_pat: Default::default(),
type_of_binding: Default::default(),
type_of_type_placeholder: Default::default(),
const_of_const_placeholder: Default::default(),
type_of_opaque: Default::default(),
type_mismatches: Default::default(),
skipped_ref_pats: Default::default(),
Expand Down Expand Up @@ -1065,6 +1067,14 @@ impl InferenceResult {
pub fn type_of_type_placeholder<'db>(&self, type_ref: TypeRefId) -> Option<Ty<'db>> {
self.type_of_type_placeholder.get(&type_ref).map(|ty| ty.as_ref())
}
pub fn placeholder_consts<'db>(&self) -> impl Iterator<Item = (TypeRefOrExprId, Const<'db>)> {
self.const_of_const_placeholder
.iter()
.map(|(&type_ref_or_const, const_)| (type_ref_or_const, const_.as_ref()))
}
pub fn const_of_const_placeholder<'db>(&self, expr: TypeRefOrExprId) -> Option<Const<'db>> {
self.const_of_const_placeholder.get(&expr).map(|ty| ty.as_ref())
}
pub fn type_of_expr_or_pat<'db>(&self, id: ExprOrPatId) -> Option<Ty<'db>> {
match id {
ExprOrPatId::ExprId(id) => self.type_of_expr.get(id).map(|it| it.as_ref()),
Expand Down Expand Up @@ -1402,6 +1412,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
type_of_pat,
type_of_binding,
type_of_type_placeholder,
const_of_const_placeholder,
type_of_opaque,
skipped_ref_pats,
type_mismatches,
Expand Down Expand Up @@ -1435,6 +1446,10 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
resolver.resolve_completely(ty);
}
type_of_type_placeholder.shrink_to_fit();
for const_ in const_of_const_placeholder.values_mut() {
resolver.resolve_completely(const_);
}
const_of_const_placeholder.shrink_to_fit();
type_of_opaque.shrink_to_fit();

if let Some(type_mismatches) = type_mismatches {
Expand Down Expand Up @@ -1708,6 +1723,10 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
self.result.type_of_type_placeholder.insert(type_ref, ty.store());
}

fn write_const_placeholder_const(&mut self, expr: TypeRefOrExprId, const_: Const<'db>) {
self.result.const_of_const_placeholder.insert(expr, const_.store());
}

fn write_binding_ty(&mut self, id: BindingId, ty: Ty<'db>) {
self.result.type_of_binding.insert(id, ty.store());
}
Expand Down Expand Up @@ -1770,23 +1789,50 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
lifetime_elision: LifetimeElisionKind<'db>,
span: Span,
) -> Ty<'db> {
use hir_def::hir::TypeRefOrExprId::*;

let ty = self
.with_ty_lowering(store, type_source, lifetime_elision, |ctx| ctx.lower_ty(type_ref));
let ty = self.process_user_written_ty(span, ty);

// Record the association from placeholders' TypeRefId to type variables.
// We only record them if their number matches. This assumes TypeRef::walk and TypeVisitable process the items in the same order.
let type_variables = collect_type_inference_vars(&ty);
let variables = collect_inference_vars(&ty);
let mut placeholder_ids = vec![];
TypeRef::walk(type_ref, store, &mut |type_ref_id, type_ref| {
if matches!(type_ref, TypeRef::Placeholder) {
placeholder_ids.push(type_ref_id);
TypeRef::walk(type_ref, store, &mut |type_ref_or_expr_id| match type_ref_or_expr_id {
TypeRefId(type_ref_id) => {
if matches!(store[type_ref_id], TypeRef::Placeholder) {
placeholder_ids.push(type_ref_or_expr_id);
}
}
ExprId(expr_id) => {
if matches!(store[expr_id], Expr::Underscore) {
placeholder_ids.push(type_ref_or_expr_id);
}
}
});

if placeholder_ids.len() == type_variables.len() {
for (placeholder_id, type_variable) in placeholder_ids.into_iter().zip(type_variables) {
self.write_type_placeholder_ty(placeholder_id, type_variable);
let interner = self.interner();
if placeholder_ids.len() == variables.len() {
for (placeholder_id, variable) in placeholder_ids.into_iter().zip(variables) {
match (placeholder_id, variable) {
(TypeRefId(idx), TyOrConstInferVar::Ty(ty_vid)) => self
.write_type_placeholder_ty(
idx,
Ty::new_infer(interner, InferTy::TyVar(ty_vid)),
),
(TypeRefId(idx), TyOrConstInferVar::TyInt(int_vid)) => {
self.write_type_placeholder_ty(idx, Ty::new_int_var(interner, int_vid))
}
(TypeRefId(idx), TyOrConstInferVar::TyFloat(float_vid)) => {
self.write_type_placeholder_ty(idx, Ty::new_float_var(interner, float_vid))
}
(_, TyOrConstInferVar::Const(const_vid)) => self.write_const_placeholder_const(
placeholder_id,
Const::new_var(interner, const_vid),
),
_ => {}
}
}
}

Expand Down
28 changes: 18 additions & 10 deletions crates/hir-ty/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ use crate::{
PolyFnSig, Region, RegionKind, TraitRef, Ty, TyKind, TypingMode,
abi::Safety,
infer::{
DbInternerInferExt,
DbInternerInferExt, TyOrConstInferVar,
traits::{Obligation, ObligationCause},
},
obligation_ctxt::ObligationCtxt,
Expand Down Expand Up @@ -628,33 +628,41 @@ where
Vec::from_iter(collector.params)
}

struct TypeInferenceVarCollector<'db> {
type_inference_vars: Vec<Ty<'db>>,
struct InferenceVarCollector {
inference_vars: Vec<TyOrConstInferVar>,
}

impl<'db> rustc_type_ir::TypeVisitor<DbInterner<'db>> for TypeInferenceVarCollector<'db> {
impl<'db> rustc_type_ir::TypeVisitor<DbInterner<'db>> for InferenceVarCollector {
type Result = ();

fn visit_ty(&mut self, ty: Ty<'db>) -> Self::Result {
use crate::rustc_type_ir::Flags;
if ty.is_ty_var() {
self.type_inference_vars.push(ty);
} else if ty.flags().intersects(rustc_type_ir::TypeFlags::HAS_TY_INFER) {
if let Some(infer_var) = TyOrConstInferVar::maybe_from_ty(ty) {
self.inference_vars.push(infer_var);
} else if ty.flags().intersects(
rustc_type_ir::TypeFlags::HAS_TY_INFER | rustc_type_ir::TypeFlags::HAS_CT_INFER,
) {
ty.super_visit_with(self);
} else {
// Fast path: don't visit inner types (e.g. generic arguments) when `flags` indicate
// that there are no placeholders.
}
}

fn visit_const(&mut self, const_: Const<'db>) -> Self::Result {
if let Some(infer_var) = TyOrConstInferVar::maybe_from_const(const_) {
self.inference_vars.push(infer_var);
}
}
}

pub fn collect_type_inference_vars<'db, T>(value: &T) -> Vec<Ty<'db>>
pub fn collect_inference_vars<'db, T>(value: &T) -> Vec<TyOrConstInferVar>
where
T: ?Sized + rustc_type_ir::TypeVisitable<DbInterner<'db>>,
{
let mut collector = TypeInferenceVarCollector { type_inference_vars: vec![] };
let mut collector = InferenceVarCollector { inference_vars: vec![] };
value.visit_with(&mut collector);
collector.type_inference_vars
collector.inference_vars
}

pub fn known_const_to_ast<'db>(
Expand Down
21 changes: 21 additions & 0 deletions crates/hir-ty/src/next_solver/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,15 @@ impl<'db> TypeVisitable<DbInterner<'db>> for Const<'db> {
}
}

impl<'db> TypeVisitable<DbInterner<'db>> for StoredConst {
fn visit_with<V: rustc_type_ir::TypeVisitor<DbInterner<'db>>>(
&self,
visitor: &mut V,
) -> V::Result {
self.as_ref().visit_with(visitor)
}
}

impl<'db> TypeSuperVisitable<DbInterner<'db>> for Const<'db> {
fn super_visit_with<V: rustc_type_ir::TypeVisitor<DbInterner<'db>>>(
&self,
Expand Down Expand Up @@ -213,6 +222,18 @@ impl<'db> TypeFoldable<DbInterner<'db>> for Const<'db> {
}
}

impl<'db> TypeFoldable<DbInterner<'db>> for StoredConst {
fn try_fold_with<F: rustc_type_ir::FallibleTypeFolder<DbInterner<'db>>>(
self,
folder: &mut F,
) -> Result<Self, F::Error> {
Ok(self.as_ref().try_fold_with(folder)?.store())
}
fn fold_with<F: rustc_type_ir::TypeFolder<DbInterner<'db>>>(self, folder: &mut F) -> Self {
self.as_ref().fold_with(folder).store()
}
}

impl<'db> TypeSuperFoldable<DbInterner<'db>> for Const<'db> {
fn try_super_fold_with<F: rustc_type_ir::FallibleTypeFolder<DbInterner<'db>>>(
self,
Expand Down
34 changes: 34 additions & 0 deletions crates/hir-ty/src/next_solver/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,40 @@ pub fn fold_tys<'db, T: TypeFoldable<DbInterner<'db>>>(
t.fold_with(&mut Folder { interner, callback })
}

pub fn fold_tys_and_consts<'db, T: TypeFoldable<DbInterner<'db>>>(
interner: DbInterner<'db>,
t: T,
ty_callback: impl FnMut(Ty<'db>) -> Ty<'db>,
const_callback: impl FnMut(Const<'db>) -> Const<'db>,
) -> T {
struct Folder<'db, F, G> {
interner: DbInterner<'db>,
ty_callback: F,
const_callback: G,
}
impl<'db, F, G> TypeFolder<DbInterner<'db>> for Folder<'db, F, G>
where
F: FnMut(Ty<'db>) -> Ty<'db>,
G: FnMut(Const<'db>) -> Const<'db>,
{
fn cx(&self) -> DbInterner<'db> {
self.interner
}

fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> {
let t = t.super_fold_with(self);
(self.ty_callback)(t)
}

fn fold_const(&mut self, c: Const<'db>) -> Const<'db> {
let c = c.super_fold_with(self);
(self.const_callback)(c)
}
}

t.fold_with(&mut Folder { interner, ty_callback, const_callback })
}

impl<'db> DbInterner<'db> {
/// Replaces all regions bound by the given `Binder` with the
/// results returned by the closure; the closure is expected to
Expand Down
Loading