about summary refs log tree commit diff
diff options
context:
space:
mode:
author: bors <bors@rust-lang.org> 2023-08-04 13:17:34 +0000
committer: bors <bors@rust-lang.org> 2023-08-04 13:17:34 +0000
commitc59bd2dc3f05692f92b5b1c76f3d08d116e63422 (patch)
tree266845f86df280cb340c2881445cc3db5df74574
parente37ec7262c095e38069671a175e5b2bf4fbbbff2 (diff)
parentcc5664c5a22f0b8fcd268f5b5866bc91dacdda6a (diff)
downloadrust-c59bd2dc3f05692f92b5b1c76f3d08d116e63422.tar.gz
rust-c59bd2dc3f05692f92b5b1c76f3d08d116e63422.zip
Auto merge of #15390 - HKalbasi:mir, r=HKalbasi
Improve mir interpreter performance by caching
-rw-r--r--crates/hir-def/src/body/lower.rs41
-rw-r--r--crates/hir-ty/src/mir/eval.rs366
-rw-r--r--crates/hir-ty/src/mir/eval/shim.rs6
3 files changed, 260 insertions, 153 deletions
diff --git a/crates/hir-def/src/body/lower.rs b/crates/hir-def/src/body/lower.rs
index c8d1ca4fa70..3df43576482 100644
--- a/crates/hir-def/src/body/lower.rs
+++ b/crates/hir-def/src/body/lower.rs
@@ -313,20 +313,7 @@ impl ExprCollector<'_> {
                 let body = self.collect_labelled_block_opt(label, e.loop_body());
                 self.alloc_expr(Expr::Loop { body, label }, syntax_ptr)
             }
-            ast::Expr::WhileExpr(e) => {
-                // Desugar `while <cond> { <body> }` to
-                // `loop { if <cond> { <body> } else { break } }`
-                let label = e.label().map(|label| self.collect_label(label));
-                let body = self.collect_labelled_block_opt(label, e.loop_body());
-                let condition = self.collect_expr_opt(e.condition());
-                let break_expr =
-                    self.alloc_expr(Expr::Break { expr: None, label: None }, syntax_ptr.clone());
-                let if_expr = self.alloc_expr(
-                    Expr::If { condition, then_branch: body, else_branch: Some(break_expr) },
-                    syntax_ptr.clone(),
-                );
-                self.alloc_expr(Expr::Loop { body: if_expr, label }, syntax_ptr)
-            }
+            ast::Expr::WhileExpr(e) => self.collect_while_loop(syntax_ptr, e),
             ast::Expr::ForExpr(e) => self.collect_for_loop(syntax_ptr, e),
             ast::Expr::CallExpr(e) => {
                 let is_rustc_box = {
@@ -738,6 +725,32 @@ impl ExprCollector<'_> {
         expr_id
     }
 
+    /// Desugar `ast::WhileExpr` from: `[opt_ident]: while <cond> <body>` into:
+    /// ```ignore (pseudo-rust)
+    /// [opt_ident]: loop {
+    ///   if <cond> {
+    ///     <body>
+    ///   }
+    ///   else {
+    ///     break;
+    ///   }
+    /// }
+    /// ```
+    /// FIXME: Rustc wraps the condition in a construct equivalent to `{ let _t = <cond>; _t }`
+    /// to preserve drop semantics. We should probably do the same in future.
+    fn collect_while_loop(&mut self, syntax_ptr: AstPtr<ast::Expr>, e: ast::WhileExpr) -> ExprId {
+        let label = e.label().map(|label| self.collect_label(label));
+        let body = self.collect_labelled_block_opt(label, e.loop_body());
+        let condition = self.collect_expr_opt(e.condition());
+        let break_expr =
+            self.alloc_expr(Expr::Break { expr: None, label: None }, syntax_ptr.clone());
+        let if_expr = self.alloc_expr(
+            Expr::If { condition, then_branch: body, else_branch: Some(break_expr) },
+            syntax_ptr.clone(),
+        );
+        self.alloc_expr(Expr::Loop { body: if_expr, label }, syntax_ptr)
+    }
+
     /// Desugar `ast::ForExpr` from: `[opt_ident]: for <pat> in <head> <body>` into:
     /// ```ignore (pseudo-rust)
     /// match IntoIterator::into_iter(<head>) {
diff --git a/crates/hir-ty/src/mir/eval.rs b/crates/hir-ty/src/mir/eval.rs
index 177c1f7ac6d..9e30eed56f3 100644
--- a/crates/hir-ty/src/mir/eval.rs
+++ b/crates/hir-ty/src/mir/eval.rs
@@ -1,6 +1,13 @@
 //! This module provides a MIR interpreter, which is used in const eval.
 
-use std::{borrow::Cow, cell::RefCell, collections::HashMap, fmt::Write, iter, mem, ops::Range};
+use std::{
+    borrow::Cow,
+    cell::RefCell,
+    collections::{HashMap, HashSet},
+    fmt::Write,
+    iter, mem,
+    ops::Range,
+};
 
 use base_db::{CrateId, FileId};
 use chalk_ir::Mutability;
@@ -39,7 +46,8 @@ use crate::{
 
 use super::{
     return_slot, AggregateKind, BasicBlockId, BinOp, CastKind, LocalId, MirBody, MirLowerError,
-    MirSpan, Operand, Place, ProjectionElem, Rvalue, StatementKind, TerminatorKind, UnOp,
+    MirSpan, Operand, Place, PlaceElem, ProjectionElem, Rvalue, StatementKind, TerminatorKind,
+    UnOp,
 };
 
 mod shim;
@@ -120,13 +128,18 @@ impl TlsData {
 }
 
 struct StackFrame {
-    body: Arc<MirBody>,
     locals: Locals,
     destination: Option<BasicBlockId>,
     prev_stack_ptr: usize,
     span: (MirSpan, DefWithBodyId),
 }
 
+#[derive(Clone)]
+enum MirOrDynIndex {
+    Mir(Arc<MirBody>),
+    Dyn(usize),
+}
+
 pub struct Evaluator<'a> {
     db: &'a dyn HirDatabase,
     trait_env: Arc<TraitEnvironment>,
@@ -145,6 +158,17 @@ pub struct Evaluator<'a> {
     stdout: Vec<u8>,
     stderr: Vec<u8>,
     layout_cache: RefCell<FxHashMap<Ty, Arc<Layout>>>,
+    projected_ty_cache: RefCell<FxHashMap<(Ty, PlaceElem), Ty>>,
+    not_special_fn_cache: RefCell<FxHashSet<FunctionId>>,
+    mir_or_dyn_index_cache: RefCell<FxHashMap<(FunctionId, Substitution), MirOrDynIndex>>,
+    /// Constantly dropping and creating `Locals` is very costly. We store
+    /// old locals that we normally want to drop here, to reuse their allocations
+    /// later.
+    unused_locals_store: RefCell<FxHashMap<DefWithBodyId, Vec<Locals>>>,
+    cached_ptr_size: usize,
+    cached_fn_trait_func: Option<FunctionId>,
+    cached_fn_mut_trait_func: Option<FunctionId>,
+    cached_fn_once_trait_func: Option<FunctionId>,
     crate_id: CrateId,
     // FIXME: This is a workaround, see the comment on `interpret_mir`
     assert_placeholder_ty_is_unused: bool,
@@ -477,6 +501,10 @@ impl DropFlags {
         }
         self.need_drop.remove(p)
     }
+
+    fn clear(&mut self) {
+        self.need_drop.clear();
+    }
 }
 
 #[derive(Debug)]
@@ -550,6 +578,26 @@ impl Evaluator<'_> {
             execution_limit: EXECUTION_LIMIT,
             memory_limit: 1000_000_000, // approx. 1GB (the literal is 1000_000_000 bytes, not 2GB)
             layout_cache: RefCell::new(HashMap::default()),
+            projected_ty_cache: RefCell::new(HashMap::default()),
+            not_special_fn_cache: RefCell::new(HashSet::default()),
+            mir_or_dyn_index_cache: RefCell::new(HashMap::default()),
+            unused_locals_store: RefCell::new(HashMap::default()),
+            cached_ptr_size: match db.target_data_layout(crate_id) {
+                Some(it) => it.pointer_size.bytes_usize(),
+                None => 8,
+            },
+            cached_fn_trait_func: db
+                .lang_item(crate_id, LangItem::Fn)
+                .and_then(|x| x.as_trait())
+                .and_then(|x| db.trait_data(x).method_by_name(&name![call])),
+            cached_fn_mut_trait_func: db
+                .lang_item(crate_id, LangItem::FnMut)
+                .and_then(|x| x.as_trait())
+                .and_then(|x| db.trait_data(x).method_by_name(&name![call_mut])),
+            cached_fn_once_trait_func: db
+                .lang_item(crate_id, LangItem::FnOnce)
+                .and_then(|x| x.as_trait())
+                .and_then(|x| db.trait_data(x).method_by_name(&name![call_once])),
         }
     }
 
@@ -570,10 +618,34 @@ impl Evaluator<'_> {
     }
 
     fn ptr_size(&self) -> usize {
-        match self.db.target_data_layout(self.crate_id) {
-            Some(it) => it.pointer_size.bytes_usize(),
-            None => 8,
+        self.cached_ptr_size
+    }
+
+    fn projected_ty(&self, ty: Ty, proj: PlaceElem) -> Ty {
+        let pair = (ty, proj);
+        if let Some(r) = self.projected_ty_cache.borrow().get(&pair) {
+            return r.clone();
         }
+        let (ty, proj) = pair;
+        let r = proj.projected_ty(
+            ty.clone(),
+            self.db,
+            |c, subst, f| {
+                let (def, _) = self.db.lookup_intern_closure(c.into());
+                let infer = self.db.infer(def);
+                let (captures, _) = infer.closure_info(&c);
+                let parent_subst = ClosureSubst(subst).parent_subst();
+                captures
+                    .get(f)
+                    .expect("broken closure field")
+                    .ty
+                    .clone()
+                    .substitute(Interner, parent_subst)
+            },
+            self.crate_id,
+        );
+        self.projected_ty_cache.borrow_mut().insert((ty, proj), r.clone());
+        r
     }
 
     fn place_addr_and_ty_and_metadata<'a>(
@@ -586,23 +658,7 @@ impl Evaluator<'_> {
         let mut metadata: Option<IntervalOrOwned> = None; // locals are always sized
         for proj in &*p.projection {
             let prev_ty = ty.clone();
-            ty = proj.projected_ty(
-                ty,
-                self.db,
-                |c, subst, f| {
-                    let (def, _) = self.db.lookup_intern_closure(c.into());
-                    let infer = self.db.infer(def);
-                    let (captures, _) = infer.closure_info(&c);
-                    let parent_subst = ClosureSubst(subst).parent_subst();
-                    captures
-                        .get(f)
-                        .expect("broken closure field")
-                        .ty
-                        .clone()
-                        .substitute(Interner, parent_subst)
-                },
-                self.crate_id,
-            );
+            ty = self.projected_ty(ty, proj.clone());
             match proj {
                 ProjectionElem::Deref => {
                     metadata = if self.size_align_of(&ty, locals)?.is_none() {
@@ -756,18 +812,18 @@ impl Evaluator<'_> {
             return Err(MirEvalError::StackOverflow);
         }
         let mut current_block_idx = body.start_block;
-        let (mut locals, prev_stack_ptr) = self.create_locals_for_body(body.clone(), None)?;
+        let (mut locals, prev_stack_ptr) = self.create_locals_for_body(&body, None)?;
         self.fill_locals_for_body(&body, &mut locals, args)?;
         let prev_code_stack = mem::take(&mut self.code_stack);
         let span = (MirSpan::Unknown, body.owner);
-        self.code_stack.push(StackFrame { body, locals, destination: None, prev_stack_ptr, span });
+        self.code_stack.push(StackFrame { locals, destination: None, prev_stack_ptr, span });
         'stack: loop {
             let Some(mut my_stack_frame) = self.code_stack.pop() else {
                 not_supported!("missing stack frame");
             };
             let e = (|| {
                 let mut locals = &mut my_stack_frame.locals;
-                let body = &*my_stack_frame.body;
+                let body = locals.body.clone();
                 loop {
                     let current_block = &body.basic_blocks[current_block_idx];
                     if let Some(it) = self.execution_limit.checked_sub(1) {
@@ -836,7 +892,7 @@ impl Evaluator<'_> {
                             locals.drop_flags.add_place(destination.clone());
                             if let Some(stack_frame) = stack_frame {
                                 self.code_stack.push(my_stack_frame);
-                                current_block_idx = stack_frame.body.start_block;
+                                current_block_idx = stack_frame.locals.body.start_block;
                                 self.code_stack.push(stack_frame);
                                 return Ok(None);
                             } else {
@@ -877,18 +933,24 @@ impl Evaluator<'_> {
                     let my_code_stack = mem::replace(&mut self.code_stack, prev_code_stack);
                     let mut error_stack = vec![];
                     for frame in my_code_stack.into_iter().rev() {
-                        if let DefWithBodyId::FunctionId(f) = frame.body.owner {
+                        if let DefWithBodyId::FunctionId(f) = frame.locals.body.owner {
                             error_stack.push((Either::Left(f), frame.span.0, frame.span.1));
                         }
                     }
                     return Err(MirEvalError::InFunction(Box::new(e), error_stack));
                 }
             };
+            let return_interval = my_stack_frame.locals.ptr[return_slot()];
+            self.unused_locals_store
+                .borrow_mut()
+                .entry(my_stack_frame.locals.body.owner)
+                .or_default()
+                .push(my_stack_frame.locals);
             match my_stack_frame.destination {
                 None => {
                     self.code_stack = prev_code_stack;
                     self.stack_depth_limit += 1;
-                    return Ok(my_stack_frame.locals.ptr[return_slot()].get(self)?.to_vec());
+                    return Ok(return_interval.get(self)?.to_vec());
                 }
                 Some(bb) => {
                     // We don't support const promotion, so we can't truncate the stack yet.
@@ -926,39 +988,45 @@ impl Evaluator<'_> {
 
     fn create_locals_for_body(
         &mut self,
-        body: Arc<MirBody>,
+        body: &Arc<MirBody>,
         destination: Option<Interval>,
     ) -> Result<(Locals, usize)> {
         let mut locals =
-            Locals { ptr: ArenaMap::new(), body: body.clone(), drop_flags: DropFlags::default() };
-        let (locals_ptr, stack_size) = {
+            match self.unused_locals_store.borrow_mut().entry(body.owner).or_default().pop() {
+                None => Locals {
+                    ptr: ArenaMap::new(),
+                    body: body.clone(),
+                    drop_flags: DropFlags::default(),
+                },
+                Some(mut l) => {
+                    l.drop_flags.clear();
+                    l.body = body.clone();
+                    l
+                }
+            };
+        let stack_size = {
             let mut stack_ptr = self.stack.len();
-            let addr = body
-                .locals
-                .iter()
-                .map(|(id, it)| {
-                    if id == return_slot() {
-                        if let Some(destination) = destination {
-                            return Ok((id, destination));
-                        }
-                    }
-                    let (size, align) = self.size_align_of_sized(
-                        &it.ty,
-                        &locals,
-                        "no unsized local in extending stack",
-                    )?;
-                    while stack_ptr % align != 0 {
-                        stack_ptr += 1;
+            for (id, it) in body.locals.iter() {
+                if id == return_slot() {
+                    if let Some(destination) = destination {
+                        locals.ptr.insert(id, destination);
+                        continue;
                     }
-                    let my_ptr = stack_ptr;
-                    stack_ptr += size;
-                    Ok((id, Interval { addr: Stack(my_ptr), size }))
-                })
-                .collect::<Result<ArenaMap<LocalId, _>>>()?;
-            let stack_size = stack_ptr - self.stack.len();
-            (addr, stack_size)
+                }
+                let (size, align) = self.size_align_of_sized(
+                    &it.ty,
+                    &locals,
+                    "no unsized local in extending stack",
+                )?;
+                while stack_ptr % align != 0 {
+                    stack_ptr += 1;
+                }
+                let my_ptr = stack_ptr;
+                stack_ptr += size;
+                locals.ptr.insert(id, Interval { addr: Stack(my_ptr), size });
+            }
+            stack_ptr - self.stack.len()
         };
-        locals.ptr = locals_ptr;
         let prev_stack_pointer = self.stack.len();
         if stack_size > self.memory_limit {
             return Err(MirEvalError::Panic(format!(
@@ -1693,6 +1761,11 @@ impl Evaluator<'_> {
     }
 
     fn size_align_of(&self, ty: &Ty, locals: &Locals) -> Result<Option<(usize, usize)>> {
+        if let Some(layout) = self.layout_cache.borrow().get(ty) {
+            return Ok(layout
+                .is_sized()
+                .then(|| (layout.size.bytes_usize(), layout.align.abi.bytes() as usize)));
+        }
         if let DefWithBodyId::VariantId(f) = locals.body.owner {
             if let Some((adt, _)) = ty.as_adt() {
                 if AdtId::from(f.parent) == adt {
@@ -1753,16 +1826,15 @@ impl Evaluator<'_> {
     }
 
     fn detect_fn_trait(&self, def: FunctionId) -> Option<FnTrait> {
-        use LangItem::*;
-        let ItemContainerId::TraitId(parent) = self.db.lookup_intern_function(def).container else {
-            return None;
-        };
-        let l = self.db.lang_attr(parent.into())?;
-        match l {
-            FnOnce => Some(FnTrait::FnOnce),
-            FnMut => Some(FnTrait::FnMut),
-            Fn => Some(FnTrait::Fn),
-            _ => None,
+        let def = Some(def);
+        if def == self.cached_fn_trait_func {
+            Some(FnTrait::Fn)
+        } else if def == self.cached_fn_mut_trait_func {
+            Some(FnTrait::FnMut)
+        } else if def == self.cached_fn_once_trait_func {
+            Some(FnTrait::FnOnce)
+        } else {
+            None
         }
     }
 
@@ -2105,6 +2177,40 @@ impl Evaluator<'_> {
         }
     }
 
+    fn get_mir_or_dyn_index(
+        &self,
+        def: FunctionId,
+        generic_args: Substitution,
+        locals: &Locals,
+        span: MirSpan,
+    ) -> Result<MirOrDynIndex> {
+        let pair = (def, generic_args);
+        if let Some(r) = self.mir_or_dyn_index_cache.borrow().get(&pair) {
+            return Ok(r.clone());
+        }
+        let (def, generic_args) = pair;
+        let r = if let Some(self_ty_idx) =
+            is_dyn_method(self.db, self.trait_env.clone(), def, generic_args.clone())
+        {
+            MirOrDynIndex::Dyn(self_ty_idx)
+        } else {
+            let (imp, generic_args) =
+                self.db.lookup_impl_method(self.trait_env.clone(), def, generic_args.clone());
+            let mir_body = self
+                .db
+                .monomorphized_mir_body(imp.into(), generic_args, self.trait_env.clone())
+                .map_err(|e| {
+                    MirEvalError::InFunction(
+                        Box::new(MirEvalError::MirLowerError(imp, e)),
+                        vec![(Either::Left(imp), span, locals.body.owner)],
+                    )
+                })?;
+            MirOrDynIndex::Mir(mir_body)
+        };
+        self.mir_or_dyn_index_cache.borrow_mut().insert((def, generic_args), r.clone());
+        Ok(r)
+    }
+
     fn exec_fn_with_args(
         &mut self,
         def: FunctionId,
@@ -2126,93 +2232,76 @@ impl Evaluator<'_> {
             return Ok(None);
         }
         let arg_bytes = args.iter().map(|it| IntervalOrOwned::Borrowed(it.interval));
-        if let Some(self_ty_idx) =
-            is_dyn_method(self.db, self.trait_env.clone(), def, generic_args.clone())
-        {
-            // In the layout of current possible receiver, which at the moment of writing this code is one of
-            // `&T`, `&mut T`, `Box<T>`, `Rc<T>`, `Arc<T>`, and `Pin<P>` where `P` is one of possible recievers,
-            // the vtable is exactly in the `[ptr_size..2*ptr_size]` bytes. So we can use it without branching on
-            // the type.
-            let first_arg = arg_bytes.clone().next().unwrap();
-            let first_arg = first_arg.get(self)?;
-            let ty =
-                self.vtable_map.ty_of_bytes(&first_arg[self.ptr_size()..self.ptr_size() * 2])?;
-            let mut args_for_target = args.to_vec();
-            args_for_target[0] = IntervalAndTy {
-                interval: args_for_target[0].interval.slice(0..self.ptr_size()),
-                ty: ty.clone(),
-            };
-            let ty = GenericArgData::Ty(ty.clone()).intern(Interner);
-            let generics_for_target = Substitution::from_iter(
-                Interner,
-                generic_args.iter(Interner).enumerate().map(|(i, it)| {
-                    if i == self_ty_idx {
-                        &ty
-                    } else {
-                        it
-                    }
-                }),
-            );
-            return self.exec_fn_with_args(
-                def,
-                &args_for_target,
-                generics_for_target,
+        match self.get_mir_or_dyn_index(def, generic_args.clone(), locals, span)? {
+            MirOrDynIndex::Dyn(self_ty_idx) => {
+                // In the layout of current possible receiver, which at the moment of writing this code is one of
+                // `&T`, `&mut T`, `Box<T>`, `Rc<T>`, `Arc<T>`, and `Pin<P>` where `P` is one of possible receivers,
+                // the vtable is exactly in the `[ptr_size..2*ptr_size]` bytes. So we can use it without branching on
+                // the type.
+                let first_arg = arg_bytes.clone().next().unwrap();
+                let first_arg = first_arg.get(self)?;
+                let ty = self
+                    .vtable_map
+                    .ty_of_bytes(&first_arg[self.ptr_size()..self.ptr_size() * 2])?;
+                let mut args_for_target = args.to_vec();
+                args_for_target[0] = IntervalAndTy {
+                    interval: args_for_target[0].interval.slice(0..self.ptr_size()),
+                    ty: ty.clone(),
+                };
+                let ty = GenericArgData::Ty(ty.clone()).intern(Interner);
+                let generics_for_target = Substitution::from_iter(
+                    Interner,
+                    generic_args.iter(Interner).enumerate().map(|(i, it)| {
+                        if i == self_ty_idx {
+                            &ty
+                        } else {
+                            it
+                        }
+                    }),
+                );
+                return self.exec_fn_with_args(
+                    def,
+                    &args_for_target,
+                    generics_for_target,
+                    locals,
+                    destination,
+                    target_bb,
+                    span,
+                );
+            }
+            MirOrDynIndex::Mir(body) => self.exec_looked_up_function(
+                body,
                 locals,
+                def,
+                arg_bytes,
+                span,
                 destination,
                 target_bb,
-                span,
-            );
+            ),
         }
-        let (imp, generic_args) =
-            self.db.lookup_impl_method(self.trait_env.clone(), def, generic_args);
-        self.exec_looked_up_function(
-            generic_args,
-            locals,
-            imp,
-            arg_bytes,
-            span,
-            destination,
-            target_bb,
-        )
     }
 
     fn exec_looked_up_function(
         &mut self,
-        generic_args: Substitution,
+        mir_body: Arc<MirBody>,
         locals: &Locals,
-        imp: FunctionId,
+        def: FunctionId,
         arg_bytes: impl Iterator<Item = IntervalOrOwned>,
         span: MirSpan,
         destination: Interval,
         target_bb: Option<BasicBlockId>,
     ) -> Result<Option<StackFrame>> {
-        let def = imp.into();
-        let mir_body = self
-            .db
-            .monomorphized_mir_body(def, generic_args, self.trait_env.clone())
-            .map_err(|e| {
-                MirEvalError::InFunction(
-                    Box::new(MirEvalError::MirLowerError(imp, e)),
-                    vec![(Either::Left(imp), span, locals.body.owner)],
-                )
-            })?;
         Ok(if let Some(target_bb) = target_bb {
             let (mut locals, prev_stack_ptr) =
-                self.create_locals_for_body(mir_body.clone(), Some(destination))?;
+                self.create_locals_for_body(&mir_body, Some(destination))?;
             self.fill_locals_for_body(&mir_body, &mut locals, arg_bytes.into_iter())?;
             let span = (span, locals.body.owner);
-            Some(StackFrame {
-                body: mir_body,
-                locals,
-                destination: Some(target_bb),
-                prev_stack_ptr,
-                span,
-            })
+            Some(StackFrame { locals, destination: Some(target_bb), prev_stack_ptr, span })
         } else {
             let result = self.interpret_mir(mir_body, arg_bytes).map_err(|e| {
                 MirEvalError::InFunction(
                     Box::new(e),
-                    vec![(Either::Left(imp), span, locals.body.owner)],
+                    vec![(Either::Left(def), span, locals.body.owner)],
                 )
             })?;
             destination.write_from_bytes(self, &result)?;
@@ -2384,16 +2473,15 @@ impl Evaluator<'_> {
             // we can ignore drop in them.
             return Ok(());
         };
-        let (impl_drop_candidate, subst) = self.db.lookup_impl_method(
-            self.trait_env.clone(),
-            drop_fn,
-            Substitution::from1(Interner, ty.clone()),
-        );
-        if impl_drop_candidate != drop_fn {
+
+        let generic_args = Substitution::from1(Interner, ty.clone());
+        if let Ok(MirOrDynIndex::Mir(body)) =
+            self.get_mir_or_dyn_index(drop_fn, generic_args, locals, span)
+        {
             self.exec_looked_up_function(
-                subst,
+                body,
                 locals,
-                impl_drop_candidate,
+                drop_fn,
                 [IntervalOrOwned::Owned(addr.to_bytes())].into_iter(),
                 span,
                 Interval { addr: Address::Invalid(0), size: 0 },
diff --git a/crates/hir-ty/src/mir/eval/shim.rs b/crates/hir-ty/src/mir/eval/shim.rs
index 6c562fe3093..b2e29fd34b5 100644
--- a/crates/hir-ty/src/mir/eval/shim.rs
+++ b/crates/hir-ty/src/mir/eval/shim.rs
@@ -36,6 +36,9 @@ impl Evaluator<'_> {
         destination: Interval,
         span: MirSpan,
     ) -> Result<bool> {
+        if self.not_special_fn_cache.borrow().contains(&def) {
+            return Ok(false);
+        }
         let function_data = self.db.function_data(def);
         let is_intrinsic = match &function_data.abi {
             Some(abi) => *abi == Interned::new_str("rust-intrinsic"),
@@ -137,8 +140,11 @@ impl Evaluator<'_> {
                     self.exec_clone(def, args, self_ty.clone(), locals, destination, span)?;
                     return Ok(true);
                 }
+                // Return early to prevent caching clone as non special fn.
+                return Ok(false);
             }
         }
+        self.not_special_fn_cache.borrow_mut().insert(def);
         Ok(false)
     }