about summary refs log tree commit diff
diff options
context:
space:
mode:
authorhkalbasi <hamidrezakalbasi@protonmail.com>2023-03-14 12:14:02 +0330
committerhkalbasi <hamidrezakalbasi@protonmail.com>2023-03-17 13:08:35 +0330
commit7525a38af5ebd9eef404b19a11cefe8a033f9d2d (patch)
tree513e6dae074877a7b2a1c947d526689953698163
parenta063f000ff99989406abd1e6f58a9c2b576ba41a (diff)
downloadrust-7525a38af5ebd9eef404b19a11cefe8a033f9d2d.tar.gz
rust-7525a38af5ebd9eef404b19a11cefe8a033f9d2d.zip
Support evaluating `dyn Trait` methods
-rw-r--r--crates/hir-ty/src/consteval/tests.rs51
-rw-r--r--crates/hir-ty/src/method_resolution.rs39
-rw-r--r--crates/hir-ty/src/mir/eval.rs158
-rw-r--r--crates/hir-ty/src/mir/lower.rs9
4 files changed, 197 insertions, 60 deletions
diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs
index 8a9a5d254df..f7914b578e4 100644
--- a/crates/hir-ty/src/consteval/tests.rs
+++ b/crates/hir-ty/src/consteval/tests.rs
@@ -1009,6 +1009,57 @@ fn function_traits() {
 }
 
 #[test]
+fn dyn_trait() {
+    check_number(
+        r#"
+    //- minicore: coerce_unsized, index, slice
+    trait Foo {
+        fn foo(&self) -> u8 { 10 }
+    }
+    struct S1;
+    struct S2;
+    struct S3;
+    impl Foo for S1 {
+        fn foo(&self) -> u8 { 1 }
+    }
+    impl Foo for S2 {
+        fn foo(&self) -> u8 { 2 }
+    }
+    impl Foo for S3 {}
+    const GOAL: u8 = {
+        let x: &[&dyn Foo] = &[&S1, &S2, &S3];
+        x[0].foo() + x[1].foo() + x[2].foo()
+    };
+        "#,
+        13,
+    );
+    check_number(
+        r#"
+    //- minicore: coerce_unsized, index, slice
+    trait Foo {
+        fn foo(&self) -> i32 { 10 }
+    }
+    trait Bar {
+        fn bar(&self) -> i32 { 20 }
+    }
+
+    struct S;
+    impl Foo for S {
+        fn foo(&self) -> i32 { 200 }
+    }
+    impl Bar for dyn Foo {
+        fn bar(&self) -> i32 { 700 }
+    }
+    const GOAL: i32 = {
+        let x: &dyn Foo = &S;
+        x.bar() + x.foo()
+    };
+        "#,
+        900,
+    );
+}
+
+#[test]
 fn array_and_index() {
     check_number(
         r#"
diff --git a/crates/hir-ty/src/method_resolution.rs b/crates/hir-ty/src/method_resolution.rs
index f105c94086c..6244b98104f 100644
--- a/crates/hir-ty/src/method_resolution.rs
+++ b/crates/hir-ty/src/method_resolution.rs
@@ -5,7 +5,7 @@
 use std::{ops::ControlFlow, sync::Arc};
 
 use base_db::{CrateId, Edition};
-use chalk_ir::{cast::Cast, Mutability, TyKind, UniverseIndex};
+use chalk_ir::{cast::Cast, Mutability, TyKind, UniverseIndex, WhereClause};
 use hir_def::{
     data::ImplData, item_scope::ItemScope, lang_item::LangItem, nameres::DefMap, AssocItemId,
     BlockId, ConstId, FunctionId, HasModule, ImplId, ItemContainerId, Lookup, ModuleDefId,
@@ -692,6 +692,38 @@ pub fn lookup_impl_const(
         .unwrap_or((const_id, subs))
 }
 
+/// Checks if the self parameter of `Trait` method is the `dyn Trait` and we should
+/// call the method using the vtable.
+pub fn is_dyn_method(
+    db: &dyn HirDatabase,
+    _env: Arc<TraitEnvironment>,
+    func: FunctionId,
+    fn_subst: Substitution,
+) -> Option<usize> {
+    let ItemContainerId::TraitId(trait_id) = func.lookup(db.upcast()).container else {
+        return None;
+    };
+    let trait_params = db.generic_params(trait_id.into()).type_or_consts.len();
+    let fn_params = fn_subst.len(Interner) - trait_params;
+    let trait_ref = TraitRef {
+        trait_id: to_chalk_trait_id(trait_id),
+        substitution: Substitution::from_iter(Interner, fn_subst.iter(Interner).skip(fn_params)),
+    };
+    let self_ty = trait_ref.self_type_parameter(Interner);
+    if let TyKind::Dyn(d) = self_ty.kind(Interner) {
+        let is_my_trait_in_bounds = d.bounds.skip_binders().as_slice(Interner).iter().any(|x| match x.skip_binders() {
+            // rustc doesn't accept `impl Foo<2> for dyn Foo<5>`, so if the trait id is equal, no matter
+            // what the generics are, we are sure that the method is come from the vtable.
+            WhereClause::Implemented(tr) => tr.trait_id == trait_ref.trait_id,
+            _ => false,
+        });
+        if is_my_trait_in_bounds {
+            return Some(fn_params);
+        }
+    }
+    None
+}
+
 /// Looks up the impl method that actually runs for the trait method `func`.
 ///
 /// Returns `func` if it's not a method defined in a trait or the lookup failed.
@@ -701,9 +733,8 @@ pub fn lookup_impl_method(
     func: FunctionId,
     fn_subst: Substitution,
 ) -> (FunctionId, Substitution) {
-    let trait_id = match func.lookup(db.upcast()).container {
-        ItemContainerId::TraitId(id) => id,
-        _ => return (func, fn_subst),
+    let ItemContainerId::TraitId(trait_id) = func.lookup(db.upcast()).container else {
+        return (func, fn_subst)
     };
     let trait_params = db.generic_params(trait_id.into()).type_or_consts.len();
     let fn_params = fn_subst.len(Interner) - trait_params;
diff --git a/crates/hir-ty/src/mir/eval.rs b/crates/hir-ty/src/mir/eval.rs
index 88ef92a4ae6..7293156a978 100644
--- a/crates/hir-ty/src/mir/eval.rs
+++ b/crates/hir-ty/src/mir/eval.rs
@@ -23,10 +23,10 @@ use crate::{
     infer::{normalize, PointerCast},
     layout::layout_of_ty,
     mapping::from_chalk,
-    method_resolution::lookup_impl_method,
+    method_resolution::{is_dyn_method, lookup_impl_method},
     traits::FnTrait,
     CallableDefId, Const, ConstScalar, FnDefId, Interner, MemoryMap, Substitution,
-    TraitEnvironment, Ty, TyBuilder, TyExt,
+    TraitEnvironment, Ty, TyBuilder, TyExt, GenericArgData,
 };
 
 use super::{
@@ -34,6 +34,15 @@ use super::{
     Operand, Place, ProjectionElem, Rvalue, StatementKind, Terminator, UnOp,
 };
 
+macro_rules! from_bytes {
+    ($ty:tt, $value:expr) => {
+        ($ty::from_le_bytes(match ($value).try_into() {
+            Ok(x) => x,
+            Err(_) => return Err(MirEvalError::TypeError("mismatched size")),
+        }))
+    };
+}
+
 #[derive(Debug, Default)]
 struct VTableMap {
     ty_to_id: HashMap<Ty, usize>,
@@ -54,6 +63,11 @@ impl VTableMap {
     fn ty(&self, id: usize) -> Result<&Ty> {
         self.id_to_ty.get(id).ok_or(MirEvalError::InvalidVTableId(id))
     }
+
+    fn ty_of_bytes(&self, bytes: &[u8]) -> Result<&Ty> {
+        let id = from_bytes!(usize, bytes);
+        self.ty(id)
+    }
 }
 
 pub struct Evaluator<'a> {
@@ -110,15 +124,6 @@ impl IntervalOrOwned {
     }
 }
 
-macro_rules! from_bytes {
-    ($ty:tt, $value:expr) => {
-        ($ty::from_le_bytes(match ($value).try_into() {
-            Ok(x) => x,
-            Err(_) => return Err(MirEvalError::TypeError("mismatched size")),
-        }))
-    };
-}
-
 impl Address {
     fn from_bytes(x: &[u8]) -> Result<Self> {
         Ok(Address::from_usize(from_bytes!(usize, x)))
@@ -781,7 +786,18 @@ impl Evaluator<'_> {
                                         }
                                         _ => not_supported!("slice unsizing from non pointers"),
                                     },
-                                    TyKind::Dyn(_) => not_supported!("dyn pointer unsize cast"),
+                                    TyKind::Dyn(_) => match &current_ty.data(Interner).kind {
+                                        TyKind::Raw(_, ty) | TyKind::Ref(_, _, ty) => {
+                                            let vtable = self.vtable_map.id(ty.clone());
+                                            let addr =
+                                                self.eval_operand(operand, locals)?.get(&self)?;
+                                            let mut r = Vec::with_capacity(16);
+                                            r.extend(addr.iter().copied());
+                                            r.extend(vtable.to_le_bytes().into_iter());
+                                            Owned(r)
+                                        }
+                                        _ => not_supported!("dyn unsizing from non pointers"),
+                                    },
                                     _ => not_supported!("unknown unsized cast"),
                                 }
                             }
@@ -1227,44 +1243,8 @@ impl Evaluator<'_> {
                 let arg_bytes = args
                     .iter()
                     .map(|x| Ok(self.eval_operand(x, &locals)?.get(&self)?.to_owned()))
-                    .collect::<Result<Vec<_>>>()?
-                    .into_iter();
-                let function_data = self.db.function_data(def);
-                let is_intrinsic = match &function_data.abi {
-                    Some(abi) => *abi == Interned::new_str("rust-intrinsic"),
-                    None => match def.lookup(self.db.upcast()).container {
-                        hir_def::ItemContainerId::ExternBlockId(block) => {
-                            let id = block.lookup(self.db.upcast()).id;
-                            id.item_tree(self.db.upcast())[id.value].abi.as_deref()
-                                == Some("rust-intrinsic")
-                        }
-                        _ => false,
-                    },
-                };
-                let result = if is_intrinsic {
-                    self.exec_intrinsic(
-                        function_data.name.as_text().unwrap_or_default().as_str(),
-                        arg_bytes,
-                        generic_args,
-                        &locals,
-                    )?
-                } else if let Some(x) = self.detect_lang_function(def) {
-                    self.exec_lang_item(x, arg_bytes)?
-                } else {
-                    let (imp, generic_args) = lookup_impl_method(
-                        self.db,
-                        self.trait_env.clone(),
-                        def,
-                        generic_args.clone(),
-                    );
-                    let generic_args = self.subst_filler(&generic_args, &locals);
-                    let def = imp.into();
-                    let mir_body =
-                        self.db.mir_body(def).map_err(|e| MirEvalError::MirLowerError(imp, e))?;
-                    self.interpret_mir(&mir_body, arg_bytes, generic_args)
-                        .map_err(|e| MirEvalError::InFunction(imp, Box::new(e)))?
-                };
-                self.write_memory(dest_addr, &result)?;
+                    .collect::<Result<Vec<_>>>()?;
+                self.exec_fn_with_args(def, arg_bytes, generic_args, locals, dest_addr)?;
             }
             CallableDefId::StructId(id) => {
                 let (size, variant_layout, tag) =
@@ -1284,6 +1264,77 @@ impl Evaluator<'_> {
         Ok(())
     }
 
+    fn exec_fn_with_args(
+        &mut self,
+        def: FunctionId,
+        arg_bytes: Vec<Vec<u8>>,
+        generic_args: Substitution,
+        locals: &Locals<'_>,
+        dest_addr: Address,
+    ) -> Result<()> {
+        let function_data = self.db.function_data(def);
+        let is_intrinsic = match &function_data.abi {
+            Some(abi) => *abi == Interned::new_str("rust-intrinsic"),
+            None => match def.lookup(self.db.upcast()).container {
+                hir_def::ItemContainerId::ExternBlockId(block) => {
+                    let id = block.lookup(self.db.upcast()).id;
+                    id.item_tree(self.db.upcast())[id.value].abi.as_deref()
+                        == Some("rust-intrinsic")
+                }
+                _ => false,
+            },
+        };
+        let result = if is_intrinsic {
+            self.exec_intrinsic(
+                function_data.name.as_text().unwrap_or_default().as_str(),
+                arg_bytes.iter().cloned(),
+                generic_args,
+                &locals,
+            )?
+        } else if let Some(x) = self.detect_lang_function(def) {
+            self.exec_lang_item(x, &arg_bytes)?
+        } else {
+            if let Some(self_ty_idx) =
+                is_dyn_method(self.db, self.trait_env.clone(), def, generic_args.clone())
+            {
+                // In the layout of current possible receiver, which at the moment of writing this code is one of
+                // `&T`, `&mut T`, `Box<T>`, `Rc<T>`, `Arc<T>`, and `Pin<P>` where `P` is one of possible recievers,
+                // the vtable is exactly in the `[ptr_size..2*ptr_size]` bytes. So we can use it without branching on
+                // the type.
+                let ty = self
+                    .vtable_map
+                    .ty_of_bytes(&arg_bytes[0][self.ptr_size()..self.ptr_size() * 2])?;
+                let ty = GenericArgData::Ty(ty.clone()).intern(Interner);
+                let mut args_for_target = arg_bytes;
+                args_for_target[0] = args_for_target[0][0..self.ptr_size()].to_vec();
+                let generics_for_target = Substitution::from_iter(
+                    Interner,
+                    generic_args
+                        .iter(Interner)
+                        .enumerate()
+                        .map(|(i, x)| if i == self_ty_idx { &ty } else { x })
+                );
+                return self.exec_fn_with_args(
+                    def,
+                    args_for_target,
+                    generics_for_target,
+                    locals,
+                    dest_addr,
+                );
+            }
+            let (imp, generic_args) =
+                lookup_impl_method(self.db, self.trait_env.clone(), def, generic_args.clone());
+            let generic_args = self.subst_filler(&generic_args, &locals);
+            let def = imp.into();
+            let mir_body =
+                self.db.mir_body(def).map_err(|e| MirEvalError::MirLowerError(imp, e))?;
+            self.interpret_mir(&mir_body, arg_bytes.iter().cloned(), generic_args)
+                .map_err(|e| MirEvalError::InFunction(imp, Box::new(e)))?
+        };
+        self.write_memory(dest_addr, &result)?;
+        Ok(())
+    }
+
     fn exec_fn_trait(
         &mut self,
         ft: FnTrait,
@@ -1317,12 +1368,9 @@ impl Evaluator<'_> {
         Ok(())
     }
 
-    fn exec_lang_item(
-        &self,
-        x: LangItem,
-        mut args: std::vec::IntoIter<Vec<u8>>,
-    ) -> Result<Vec<u8>> {
+    fn exec_lang_item(&self, x: LangItem, args: &[Vec<u8>]) -> Result<Vec<u8>> {
         use LangItem::*;
+        let mut args = args.iter();
         match x {
             PanicFmt | BeginPanic => Err(MirEvalError::Panic),
             SliceLen => {
diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs
index 7a5ca089420..4fc3c67a6e1 100644
--- a/crates/hir-ty/src/mir/lower.rs
+++ b/crates/hir-ty/src/mir/lower.rs
@@ -230,7 +230,14 @@ impl MirLowerCtx<'_> {
                                     self.lower_const(c, current, place, expr_id.into())?;
                                     return Ok(Some(current))
                                 },
-                                _ => not_supported!("associated functions and types"),
+                                hir_def::AssocItemId::FunctionId(_) => {
+                                    // FnDefs are zero sized, no action is needed.
+                                    return Ok(Some(current))
+                                }
+                                hir_def::AssocItemId::TypeAliasId(_) => {
+                                    // FIXME: If it is unreachable, use proper error instead of `not_supported`.
+                                    not_supported!("associated functions and types")
+                                },
                             }
                         } else if let Some(variant) = self
                             .infer