about summary refs log tree commit diff
path: root/src/comp
diff options
context:
space:
mode:
authorMarijn Haverbeke <marijnh@gmail.com>2012-02-21 17:02:02 +0100
committerMarijn Haverbeke <marijnh@gmail.com>2012-02-21 17:08:14 +0100
commit9d20ed7bf97e533e0cc5d7be2c3ec5d5dfd30e98 (patch)
tree3d71c5f084d1e84374cf532ad9dfe257c899ba96 /src/comp
parentff927f18f5b118bd26ba8a0826b317c27daba70d (diff)
downloadrust-9d20ed7bf97e533e0cc5d7be2c3ec5d5dfd30e98.tar.gz
rust-9d20ed7bf97e533e0cc5d7be2c3ec5d5dfd30e98.zip
Clean up unification code
Diffstat (limited to 'src/comp')
-rw-r--r--src/comp/middle/ty.rs444
1 files changed, 116 insertions, 328 deletions
diff --git a/src/comp/middle/ty.rs b/src/comp/middle/ty.rs
index ff80ecd7e74..6b0c6f5d1dc 100644
--- a/src/comp/middle/ty.rs
+++ b/src/comp/middle/ty.rs
@@ -258,6 +258,7 @@ enum type_err {
     terr_mismatch,
     terr_ret_style_mismatch(ast::ret_style, ast::ret_style),
     terr_box_mutability,
+    terr_ptr_mutability,
     terr_vec_mutability,
     terr_tuple_size(uint, uint),
     terr_record_size(uint, uint),
@@ -1557,26 +1558,9 @@ mod unify {
         }
     }
 
-    fn record_var_binding_for_expected(
-        cx: @uctxt, key: int, typ: t, variance: variance) -> result {
-        record_var_binding(
-            cx, key, typ, variance_transform(variance, covariant))
-    }
-
-    fn record_var_binding_for_actual(
-        cx: @uctxt, key: int, typ: t, variance: variance) -> result {
-        // Unifying in 'the other direction' so flip the variance
-        record_var_binding(
-            cx, key, typ, variance_transform(variance, contravariant))
-    }
-
-    fn record_var_binding(
-        cx: @uctxt, key: int, typ: t, variance: variance) -> result {
-
-        let vb = alt cx.st { in_bindings(vb) { vb }
-            _ { cx.tcx.sess.bug("Someone forgot to document an invariant \
-                 in record_var_binding");  }
-        };
+    fn record_var_binding(cx: @uctxt, key: int, typ: t, variance: variance)
+        -> result {
+        let vb = alt check cx.st { in_bindings(vb) { vb } };
         ufind::grow(vb.sets, (key as uint) + 1u);
         let root = ufind::find(vb.sets, key as uint);
         let result_type = typ;
@@ -1589,8 +1573,8 @@ mod unify {
           }
           none {/* fall through */ }
         }
-        smallintmap::insert::<t>(vb.types, root, result_type);
-        ret ures_ok(typ);
+        smallintmap::insert(vb.types, root, result_type);
+        ret ures_ok(mk_var(cx.tcx, key));
     }
 
     // Simple structural type comparison.
@@ -1837,7 +1821,7 @@ mod unify {
     }
 
     fn unify_tps(cx: @uctxt, expected_tps: [t], actual_tps: [t],
-                 variance: variance, finish: fn([t]) -> result) -> result {
+                 variance: variance, finish: fn([t]) -> t) -> result {
         let result_tps = [], i = 0u;
         for exp in expected_tps {
             let act = actual_tps[i];
@@ -1848,345 +1832,148 @@ mod unify {
               _ { ret result; }
             }
         }
-        finish(result_tps)
+        ures_ok(finish(result_tps))
+    }
+    fn unify_mt(cx: @uctxt, e_mt: mt, a_mt: mt, variance: variance,
+                mut_err: type_err, finish: fn(ctxt, mt) -> t) -> result {
+        alt unify_mut(e_mt.mutbl, a_mt.mutbl, variance) {
+          none { ures_err(mut_err) }
+          some((mutt, var)) {
+            alt unify_step(cx, e_mt.ty, a_mt.ty, var) {
+              ures_ok(result_sub) {
+                ures_ok(finish(cx.tcx, {ty: result_sub, mutbl: mutt}))
+              }
+              err { err }
+            }
+          }
+        }
     }
+
     fn unify_step(cx: @uctxt, expected: t, actual: t,
                   variance: variance) -> result {
-        // FIXME: rewrite this using tuple pattern matching when available, to
-        // avoid all this rightward drift and spikiness.
-        // NOTE: we have tuple matching now, but that involves copying the
-        // matched elements into a tuple first, which is expensive, since sty
-        // holds vectors, which are currently unique
-
         // Fast path.
         if expected == actual { ret ures_ok(expected); }
 
-        // Stage 1: Handle the cases in which one side or another is a type
-        // variable
-
-        alt get(actual).struct {
-          // If the RHS is a variable type, then just do the
-          // appropriate binding.
-          ty_var(actual_id) {
-            let actual_n = actual_id as uint;
-            alt get(expected).struct {
-              ty_var(expected_id) {
-                let expected_n = expected_id as uint;
-                alt union(cx, expected_n, actual_n, variance) {
-                  unres_ok {/* fall through */ }
-                  unres_err(t_e) { ret ures_err(t_e); }
-                }
-              }
-              _ {
-                // Just bind the type variable to the expected type.
-                alt record_var_binding_for_actual(
-                    cx, actual_id, expected, variance) {
-                  ures_ok(_) {/* fall through */ }
-                  rs { ret rs; }
-                }
-              }
+        alt (get(expected).struct, get(actual).struct) {
+          (ty_var(e_id), ty_var(a_id)) {
+            alt union(cx, e_id as uint, a_id as uint, variance) {
+              unres_ok { ures_ok(actual) }
+              unres_err(err) { ures_err(err) }
             }
-            ret ures_ok(mk_var(cx.tcx, actual_id));
           }
-          _ {/* empty */ }
-        }
-        alt get(expected).struct {
-          ty_var(expected_id) {
-            // Add a binding. (`actual` can't actually be a var here.)
-            alt record_var_binding_for_expected(
-                cx, expected_id, actual,
-                variance) {
-              ures_ok(_) {/* fall through */ }
-              rs { ret rs; }
-            }
-            ret ures_ok(mk_var(cx.tcx, expected_id));
+          (_, ty_var(a_id)) {
+            let v = variance_transform(variance, contravariant);
+            record_var_binding(cx, a_id, expected, v)
           }
-          _ {/* fall through */ }
-        }
-        // Stage 2: Handle all other cases.
-
-        alt get(actual).struct {
-          ty_bot { ret ures_ok(expected); }
-          _ {/* fall through */ }
-        }
-        alt get(expected).struct {
-          ty_nil { ret struct_cmp(cx, expected, actual); }
-          // _|_ unifies with anything
-          ty_bot {
-            ret ures_ok(actual);
+          (ty_var(e_id), _) {
+            let v = variance_transform(variance, covariant);
+            record_var_binding(cx, e_id, actual, v)
           }
-          ty_bool | ty_int(_) | ty_uint(_) | ty_float(_) |
-          ty_str | ty_send_type {
-            ret struct_cmp(cx, expected, actual);
+          (_, ty_bot) { ures_ok(expected) }
+          (ty_bot, _) { ures_ok(actual) }
+          (ty_nil, _) | (ty_bool, _) | (ty_int(_), _) | (ty_uint(_), _) |
+          (ty_float(_), _) | (ty_str, _) | (ty_send_type, _) {
+            struct_cmp(cx, expected, actual)
           }
-          ty_param(expected_n, _) {
-            alt get(actual).struct {
-              ty_param(actual_n, _) if expected_n == actual_n {
-                ret ures_ok(expected);
-              }
-              _ { ret ures_err(terr_mismatch); }
-            }
+          (ty_param(e_n, _), ty_param(a_n, _)) if e_n == a_n {
+            ures_ok(expected)
           }
-          ty_enum(expected_id, expected_tps) {
-            alt get(actual).struct {
-              ty_enum(actual_id, actual_tps) {
-                if expected_id != actual_id {
-                    ret ures_err(terr_mismatch);
-                }
-                ret unify_tps(cx, expected_tps, actual_tps, variance, {|tps|
-                    ures_ok(mk_enum(cx.tcx, expected_id, tps))
-                });
-              }
-              _ {/* fall through */ }
-            }
-            ret ures_err(terr_mismatch);
+          (ty_enum(e_id, e_tps), ty_enum(a_id, a_tps)) if e_id == a_id {
+            unify_tps(cx, e_tps, a_tps, variance, {|tps|
+                mk_enum(cx.tcx, e_id, tps)
+            })
           }
-          ty_iface(expected_id, expected_tps) {
-            alt get(actual).struct {
-              ty_iface(actual_id, actual_tps) {
-                if expected_id != actual_id {
-                    ret ures_err(terr_mismatch);
-                }
-                ret unify_tps(cx, expected_tps, actual_tps, variance, {|tps|
-                    ures_ok(mk_iface(cx.tcx, expected_id, tps))
-                });
-              }
-              _ {}
-            }
-            ret ures_err(terr_mismatch);
+          (ty_iface(e_id, e_tps), ty_iface(a_id, a_tps)) if e_id == a_id {
+            unify_tps(cx, e_tps, a_tps, variance, {|tps|
+                mk_iface(cx.tcx, e_id, tps)
+            })
           }
-          ty_box(expected_mt) {
-            alt get(actual).struct {
-              ty_box(actual_mt) {
-                let (mutt, var) = alt unify_mut(
-                    expected_mt.mutbl, actual_mt.mutbl, variance) {
-                  none { ret ures_err(terr_box_mutability); }
-                  some(mv) { mv }
-                };
-                let result = unify_step(
-                    cx, expected_mt.ty, actual_mt.ty, var);
-                alt result {
-                  ures_ok(result_sub) {
-                    let mt = {ty: result_sub, mutbl: mutt};
-                    ret ures_ok(mk_box(cx.tcx, mt));
-                  }
-                  _ { ret result; }
-                }
-              }
-              _ { ret ures_err(terr_mismatch); }
-            }
+          (ty_class(e_id, e_tps), ty_class(a_id, a_tps)) if e_id == a_id {
+            unify_tps(cx, e_tps, a_tps, variance, {|tps|
+                mk_class(cx.tcx, e_id, tps)
+            })
           }
-          ty_uniq(expected_mt) {
-            alt get(actual).struct {
-              ty_uniq(actual_mt) {
-                let (mutt, var) = alt unify_mut(
-                    expected_mt.mutbl, actual_mt.mutbl, variance) {
-                  none { ret ures_err(terr_box_mutability); }
-                  some(mv) { mv }
-                };
-                let result = unify_step(
-                    cx, expected_mt.ty, actual_mt.ty, var);
-                alt result {
-                  ures_ok(result_mt) {
-                    let mt = {ty: result_mt, mutbl: mutt};
-                    ret ures_ok(mk_uniq(cx.tcx, mt));
-                  }
-                  _ { ret result; }
-                }
-              }
-              _ { ret ures_err(terr_mismatch); }
-            }
+          (ty_box(e_mt), ty_box(a_mt)) {
+            unify_mt(cx, e_mt, a_mt, variance, terr_box_mutability, mk_box)
           }
-          ty_vec(expected_mt) {
-            alt get(actual).struct {
-              ty_vec(actual_mt) {
-                let (mutt, var) = alt unify_mut(
-                    expected_mt.mutbl, actual_mt.mutbl, variance) {
-                  none { ret ures_err(terr_vec_mutability); }
-                  some(mv) { mv }
-                };
-                let result = unify_step(
-                    cx, expected_mt.ty, actual_mt.ty, var);
-                alt result {
-                  ures_ok(result_sub) {
-                    let mt = {ty: result_sub, mutbl: mutt};
-                    ret ures_ok(mk_vec(cx.tcx, mt));
-                  }
-                  _ { ret result; }
-                }
-              }
-              _ { ret ures_err(terr_mismatch); }
-            }
+          (ty_uniq(e_mt), ty_uniq(a_mt)) {
+            unify_mt(cx, e_mt, a_mt, variance, terr_box_mutability, mk_uniq)
           }
-          ty_ptr(expected_mt) {
-            alt get(actual).struct {
-              ty_ptr(actual_mt) {
-                let (mutt, var) = alt unify_mut(
-                    expected_mt.mutbl, actual_mt.mutbl, variance) {
-                  none { ret ures_err(terr_vec_mutability); }
-                  some(mv) { mv }
-                };
-                let result = unify_step(
-                    cx, expected_mt.ty, actual_mt.ty, var);
-                alt result {
-                  ures_ok(result_sub) {
-                    let mt = {ty: result_sub, mutbl: mutt};
-                    ret ures_ok(mk_ptr(cx.tcx, mt));
-                  }
-                  _ { ret result; }
-                }
-              }
-              _ { ret ures_err(terr_mismatch); }
-            }
+          (ty_vec(e_mt), ty_vec(a_mt)) {
+            unify_mt(cx, e_mt, a_mt, variance, terr_vec_mutability, mk_vec)
           }
-          ty_res(ex_id, ex_inner, ex_tps) {
-            alt get(actual).struct {
-              ty_res(act_id, act_inner, act_tps) {
-                if ex_id.crate != act_id.crate || ex_id.node != act_id.node {
-                    ret ures_err(terr_mismatch);
-                }
-                let result = unify_step(
-                    cx, ex_inner, act_inner, variance);
-                alt result {
-                  ures_ok(res_inner) {
-                    let i = 0u;
-                    let res_tps = [];
-                    for ex_tp: t in ex_tps {
-                        let result = unify_step(
-                            cx, ex_tp, act_tps[i], variance);
-                        alt result {
-                          ures_ok(rty) { res_tps += [rty]; }
-                          _ { ret result; }
-                        }
-                        i += 1u;
-                    }
-                    ret ures_ok(mk_res(cx.tcx, act_id, res_inner, res_tps));
-                  }
-                  _ { ret result; }
-                }
+          (ty_ptr(e_mt), ty_ptr(a_mt)) {
+            unify_mt(cx, e_mt, a_mt, variance, terr_ptr_mutability, mk_ptr)
+          }
+          (ty_res(e_id, e_inner, e_tps), ty_res(a_id, a_inner, a_tps))
+          if e_id == a_id {
+            alt unify_step(cx, e_inner, a_inner, variance) {
+              ures_ok(res_inner) {
+                unify_tps(cx, e_tps, a_tps, variance, {|tps|
+                    mk_res(cx.tcx, a_id, res_inner, tps)
+                })
               }
-              _ { ret ures_err(terr_mismatch); }
+              err { err }
             }
           }
-          ty_rec(expected_fields) {
-            alt get(actual).struct {
-              ty_rec(actual_fields) {
-                let expected_len = vec::len::<field>(expected_fields);
-                let actual_len = vec::len::<field>(actual_fields);
-                if expected_len != actual_len {
-                    let err = terr_record_size(expected_len, actual_len);
-                    ret ures_err(err);
+          (ty_rec(e_fields), ty_rec(a_fields)) {
+            let e_len = e_fields.len(), a_len = a_fields.len();
+            if e_len != a_len {
+                ret ures_err(terr_record_size(e_len, a_len));
+            }
+            let result_fields = [], i = 0u;
+            while i < a_len {
+                let e_field = e_fields[i], a_field = a_fields[i];
+                if e_field.ident != a_field.ident {
+                    ret ures_err(terr_record_fields(e_field.ident,
+                                                    a_field.ident));
                 }
-
-                let result_fields = [], i = 0u;
-                while i < actual_len {
-                    let expected_field = expected_fields[i],
-                        actual_field = actual_fields[i];
-                    let u_mut = unify_mut(expected_field.mt.mutbl,
-                                          actual_field.mt.mutbl,
-                                          variance);
-                    let (mutt, var) = alt u_mut {
-                      none { ret ures_err(terr_record_mutability); }
-                      some(mv) { mv }
-                    };
-                    if !str::eq(expected_field.ident, actual_field.ident) {
-                        let err =
-                            terr_record_fields(expected_field.ident,
-                                               actual_field.ident);
-                        ret ures_err(err);
-                    }
-                    let result =
-                        unify_step(cx, expected_field.mt.ty,
-                                   actual_field.mt.ty, var);
-                    alt result {
-                      ures_ok(rty) {
-                        let mt = {ty: rty, mutbl: mutt};
-                        result_fields += [{mt: mt with expected_field}];
-                      }
-                      _ { ret result; }
-                    }
-                    i += 1u;
+                alt unify_mt(cx, e_field.mt, a_field.mt, variance,
+                             terr_record_mutability, {|cx, mt|
+                    result_fields += [{mt: mt with e_field}];
+                    mk_nil(cx)
+                }) {
+                  ures_ok(_) {}
+                  err { ret err; }
                 }
-                ret ures_ok(mk_rec(cx.tcx, result_fields));
-              }
-              _ { ret ures_err(terr_mismatch); }
+                i += 1u;
             }
+            ures_ok(mk_rec(cx.tcx, result_fields))
           }
-          ty_tup(expected_elems) {
-            alt get(actual).struct {
-              ty_tup(actual_elems) {
-                let expected_len = vec::len(expected_elems);
-                let actual_len = vec::len(actual_elems);
-                if expected_len != actual_len {
-                    let err = terr_tuple_size(expected_len, actual_len);
-                    ret ures_err(err);
+          (ty_tup(e_elems), ty_tup(a_elems)) {
+            let e_len = e_elems.len(), a_len = a_elems.len();
+            if e_len != a_len { ret ures_err(terr_tuple_size(e_len, a_len)); }
+            let result_elems = [], i = 0u;
+            while i < a_len {
+                alt unify_step(cx, e_elems[i], a_elems[i], variance) {
+                  ures_ok(rty) { result_elems += [rty]; }
+                  err { ret err; }
                 }
-
-                let result_elems = [], i = 0u;
-                while i < actual_len {
-                    alt unify_step(cx, expected_elems[i], actual_elems[i],
-                                   variance) {
-                      ures_ok(rty) { result_elems += [rty]; }
-                      r { ret r; }
-                    }
-                    i += 1u;
-                }
-                ret ures_ok(mk_tup(cx.tcx, result_elems));
-              }
-              _ { ret ures_err(terr_mismatch); }
+                i += 1u;
             }
+            ures_ok(mk_tup(cx.tcx, result_elems))
           }
-          ty_fn(expected_f) {
-            alt get(actual).struct {
-              ty_fn(actual_f) {
-                ret unify_fn(cx, expected_f, actual_f, variance);
-              }
-              _ { ret ures_err(terr_mismatch); }
-            }
+          (ty_fn(e_fty), ty_fn(a_fty)) {
+            unify_fn(cx, e_fty, a_fty, variance)
           }
-          ty_constr(expected_t, expected_constrs) {
-
+          (ty_constr(e_t, e_constrs), ty_constr(a_t, a_constrs)) {
             // unify the base types...
-            alt get(actual).struct {
-              ty_constr(actual_t, actual_constrs) {
-                let rslt = unify_step(
-                    cx, expected_t, actual_t, variance);
-                alt rslt {
-                  ures_ok(rty) {
-                    // FIXME: probably too restrictive --
-                    // requires the constraints to be
-                    // syntactically equal
-                    ret unify_constrs(expected, expected_constrs,
-                                      actual_constrs);
-                  }
-                  _ { ret rslt; }
-                }
-              }
-              _ {
-                // If the actual type is *not* a constrained type,
-                // then we go ahead and just ignore the constraints on
-                // the expected type. typestate handles the rest.
-                ret unify_step(
-                    cx, expected_t, actual, variance);
+            alt unify_step(cx, e_t, a_t, variance) {
+              ures_ok(rty) {
+                // FIXME: probably too restrictive --
+                // requires the constraints to be syntactically equal
+                unify_constrs(expected, e_constrs, a_constrs)
               }
+              err { err }
             }
           }
-          ty_class(expected_class, expected_tys) {
-              alt get(actual).struct {
-                ty_class(actual_class, actual_tys) {
-                    if expected_class != actual_class {
-                        ret ures_err(terr_mismatch);
-                    }
-                    ret unify_tps(cx, expected_tys, actual_tys, variance,
-                           {|tps|
-                            ures_ok(mk_class(cx.tcx, expected_class, tps))});
-                }
-                _ {
-                    ret ures_err(terr_mismatch);
-                }
-              }
+          (ty_constr(e_t, _), _) {
+            // If the actual type is *not* a constrained type,
+            // then we go ahead and just ignore the constraints on
+            // the expected type. typestate handles the rest.
+            unify_step(cx, e_t, actual, variance)
           }
-          _ { cx.tcx.sess.bug("unify: unexpected type"); }
+          _ { ures_err(terr_mismatch) }
         }
     }
     fn unify(expected: t, actual: t, st: unify_style,
@@ -2293,6 +2080,7 @@ fn type_err_to_str(err: type_err) -> str {
       }
       terr_box_mutability { ret "boxed values differ in mutability"; }
       terr_vec_mutability { ret "vectors differ in mutability"; }
+      terr_ptr_mutability { ret "pointers differ in mutability"; }
       terr_tuple_size(e_sz, a_sz) {
         ret "expected a tuple with " + uint::to_str(e_sz, 10u) +
                 " elements but found one with " + uint::to_str(a_sz, 10u) +