about summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Walton <pcwalton@mimiga.net>2014-07-01 14:30:33 -0700
committerPatrick Walton <pcwalton@mimiga.net>2014-07-01 14:32:57 -0700
commit454b9d2d1f742d25f95b956649a9a6d92b522da4 (patch)
treeb8afca91bc021fe0f3871bf19a08f570a1cf2847
parent14c0b3ab42a8f963b5e766605f31028e30cd9c0a (diff)
downloadrust-454b9d2d1f742d25f95b956649a9a6d92b522da4.tar.gz
rust-454b9d2d1f742d25f95b956649a9a6d92b522da4.zip
librustc: Fix `expr_use_visitor` (and, transitively, the borrow check)
with overloaded calls.

This enforces the mutability and borrow restrictions around overloaded
calls.

Closes #14774.

[breaking-change]
-rw-r--r--src/librustc/middle/expr_use_visitor.rs107
-rw-r--r--src/test/compile-fail/borrowck-overloaded-call.rs74
2 files changed, 171 insertions, 10 deletions
diff --git a/src/librustc/middle/expr_use_visitor.rs b/src/librustc/middle/expr_use_visitor.rs
index 3e72341ebb0..1e06b3b1fd4 100644
--- a/src/librustc/middle/expr_use_visitor.rs
+++ b/src/librustc/middle/expr_use_visitor.rs
@@ -19,13 +19,14 @@ use middle::def;
 use middle::freevars;
 use middle::pat_util;
 use middle::ty;
-use middle::typeck::MethodCall;
+use middle::typeck::{MethodCall, MethodObject, MethodOrigin, MethodParam};
+use middle::typeck::{MethodStatic};
 use middle::typeck;
-use syntax::ast;
-use syntax::codemap::{Span};
 use util::ppaux::Repr;
 
 use std::gc::Gc;
+use syntax::ast;
+use syntax::codemap::Span;
 
 ///////////////////////////////////////////////////////////////////////////
 // The Delegate trait
@@ -101,6 +102,74 @@ pub enum MutateMode {
     WriteAndRead, // x += y
 }
 
+enum OverloadedCallType {
+    FnOverloadedCall,
+    FnMutOverloadedCall,
+    FnOnceOverloadedCall,
+}
+
+impl OverloadedCallType {
+    fn from_trait_id(tcx: &ty::ctxt, trait_id: ast::DefId)
+                     -> OverloadedCallType {
+        for &(maybe_function_trait, overloaded_call_type) in [
+            (tcx.lang_items.fn_once_trait(), FnOnceOverloadedCall),
+            (tcx.lang_items.fn_mut_trait(), FnMutOverloadedCall),
+            (tcx.lang_items.fn_trait(), FnOverloadedCall)
+        ].iter() {
+            match maybe_function_trait {
+                Some(function_trait) if function_trait == trait_id => {
+                    return overloaded_call_type
+                }
+                _ => continue,
+            }
+        }
+
+        tcx.sess.bug("overloaded call didn't map to known function trait")
+    }
+
+    fn from_method_id(tcx: &ty::ctxt, method_id: ast::DefId)
+                      -> OverloadedCallType {
+        let method_descriptor =
+            match tcx.methods.borrow_mut().find(&method_id) {
+                None => {
+                    tcx.sess.bug("overloaded call method wasn't in method \
+                                  map")
+                }
+                Some(ref method_descriptor) => (*method_descriptor).clone(),
+            };
+        let impl_id = match method_descriptor.container {
+            ty::TraitContainer(_) => {
+                tcx.sess.bug("statically resolved overloaded call method \
+                              belonged to a trait?!")
+            }
+            ty::ImplContainer(impl_id) => impl_id,
+        };
+        let trait_ref = match ty::impl_trait_ref(tcx, impl_id) {
+            None => {
+                tcx.sess.bug("statically resolved overloaded call impl \
+                              didn't implement a trait?!")
+            }
+            Some(ref trait_ref) => (*trait_ref).clone(),
+        };
+        OverloadedCallType::from_trait_id(tcx, trait_ref.def_id)
+    }
+
+    fn from_method_origin(tcx: &ty::ctxt, origin: &MethodOrigin)
+                          -> OverloadedCallType {
+        match *origin {
+            MethodStatic(def_id) => {
+                OverloadedCallType::from_method_id(tcx, def_id)
+            }
+            MethodParam(ref method_param) => {
+                OverloadedCallType::from_trait_id(tcx, method_param.trait_id)
+            }
+            MethodObject(ref method_object) => {
+                OverloadedCallType::from_trait_id(tcx, method_object.trait_id)
+            }
+        }
+    }
+}
+
 ///////////////////////////////////////////////////////////////////////////
 // The ExprUseVisitor type
 //
@@ -413,19 +482,37 @@ impl<'d,'t,TYPER:mc::Typer> ExprUseVisitor<'d,'t,TYPER> {
                 }
             }
             _ => {
-                match self.tcx()
-                          .method_map
-                          .borrow()
-                          .find(&MethodCall::expr(call.id)) {
-                    Some(_) => {
-                        // FIXME(#14774, pcwalton): Implement this.
+                let overloaded_call_type =
+                    match self.tcx()
+                              .method_map
+                              .borrow()
+                              .find(&MethodCall::expr(call.id)) {
+                    Some(ref method_callee) => {
+                        OverloadedCallType::from_method_origin(
+                            self.tcx(),
+                            &method_callee.origin)
                     }
                     None => {
                         self.tcx().sess.span_bug(
                             callee.span,
                             format!("unexpected callee type {}",
-                                    callee_ty.repr(self.tcx())).as_slice());
+                                    callee_ty.repr(self.tcx())).as_slice())
+                    }
+                };
+                match overloaded_call_type {
+                    FnMutOverloadedCall => {
+                        self.borrow_expr(callee,
+                                         ty::ReScope(call.id),
+                                         ty::MutBorrow,
+                                         ClosureInvocation);
+                    }
+                    FnOverloadedCall => {
+                        self.borrow_expr(callee,
+                                         ty::ReScope(call.id),
+                                         ty::ImmBorrow,
+                                         ClosureInvocation);
                     }
+                    FnOnceOverloadedCall => self.consume_expr(callee),
                 }
             }
         }
diff --git a/src/test/compile-fail/borrowck-overloaded-call.rs b/src/test/compile-fail/borrowck-overloaded-call.rs
new file mode 100644
index 00000000000..349a20313fa
--- /dev/null
+++ b/src/test/compile-fail/borrowck-overloaded-call.rs
@@ -0,0 +1,74 @@
+// Copyright 2012 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![feature(overloaded_calls)]
+
+use std::ops::{Fn, FnMut, FnOnce};
+
+struct SFn {
+    x: int,
+    y: int,
+}
+
+impl Fn<(int,),int> for SFn {
+    fn call(&self, (z,): (int,)) -> int {
+        self.x * self.y * z
+    }
+}
+
+struct SFnMut {
+    x: int,
+    y: int,
+}
+
+impl FnMut<(int,),int> for SFnMut {
+    fn call_mut(&mut self, (z,): (int,)) -> int {
+        self.x * self.y * z
+    }
+}
+
+struct SFnOnce {
+    x: String,
+}
+
+impl FnOnce<(String,),uint> for SFnOnce {
+    fn call_once(self, (z,): (String,)) -> uint {
+        self.x.len() + z.len()
+    }
+}
+
+fn f() {
+    let mut s = SFn {
+        x: 1,
+        y: 2,
+    };
+    let sp = &mut s;
+    s(3);   //~ ERROR cannot borrow `s` as immutable because it is also borrowed as mutable
+    //~^ ERROR cannot borrow `s` as immutable because it is also borrowed as mutable
+}
+
+fn g() {
+    let s = SFnMut {
+        x: 1,
+        y: 2,
+    };
+    s(3);   //~ ERROR cannot borrow immutable local variable `s` as mutable
+}
+
+fn h() {
+    let s = SFnOnce {
+        x: "hello".to_string(),
+    };
+    s(" world".to_string());
+    s(" world".to_string());    //~ ERROR use of moved value: `s`
+}
+
+fn main() {}
+