about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorLukas Wirth <lukastw97@gmail.com>2024-12-28 19:19:06 +0100
committerLukas Wirth <lukastw97@gmail.com>2024-12-28 19:54:22 +0100
commitd66a337658a4e175380c1ff59a73375b76237b9f (patch)
tree86ca898be351cb26854112a1cea88225edb3cc70 /src
parent0e50c3c81be8bc1c80a7a5ed833ff3fc98e3257f (diff)
downloadrust-d66a337658a4e175380c1ff59a73375b76237b9f.tar.gz
rust-d66a337658a4e175380c1ff59a73375b76237b9f.zip
Get rid of constrain and solve steps
Diffstat (limited to 'src')
-rw-r--r--src/tools/rust-analyzer/crates/hir-ty/src/generics.rs8
-rw-r--r--src/tools/rust-analyzer/crates/hir-ty/src/variance.rs131
2 files changed, 44 insertions, 95 deletions
diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs b/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs
index e7a2721afee..fe7541d2374 100644
--- a/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs
+++ b/src/tools/rust-analyzer/crates/hir-ty/src/generics.rs
@@ -132,14 +132,6 @@ impl Generics {
         self.params.len()
     }
 
-    pub(crate) fn len_self_lifetimes(&self) -> usize {
-        self.params.len_lifetimes()
-    }
-
-    pub(crate) fn has_trait_self(&self) -> bool {
-        self.params.trait_self_param().is_some()
-    }
-
     /// (parent total, self param, type params, const params, impl trait list, lifetimes)
     pub(crate) fn provenance_split(&self) -> (usize, bool, usize, usize, usize, usize) {
         let mut self_param = false;
diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs
index ca16e986af5..0cce1aec2b4 100644
--- a/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs
+++ b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs
@@ -39,19 +39,9 @@ pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option<Ar
     if count == 0 {
         return None;
     }
-    let mut ctxt = Context {
-        def,
-        has_trait_self: generics.parent_generics().map_or(false, |it| it.has_trait_self()),
-        len_self: generics.len_self(),
-        len_self_lifetimes: generics.len_self_lifetimes(),
-        generics,
-        constraints: Vec::new(),
-        db,
-    };
+    let variances = Context { generics, variances: vec![Variance::Bivariant; count], db }.solve();
 
-    ctxt.build_constraints_for_item();
-    let res = ctxt.solve();
-    res.is_empty().not().then(|| Arc::from_iter(res))
+    variances.is_empty().not().then(|| Arc::from_iter(variances))
 }
 
 pub(crate) fn variances_of_cycle(
@@ -172,25 +162,14 @@ struct InferredIndex(usize);
 
 struct Context<'db> {
     db: &'db dyn HirDatabase,
-    def: GenericDefId,
-    has_trait_self: bool,
-    len_self: usize,
-    len_self_lifetimes: usize,
     generics: Generics,
-    constraints: Vec<Constraint>,
-}
-
-/// Declares that the variable `decl_id` appears in a location with
-/// variance `variance`.
-#[derive(Clone)]
-struct Constraint {
-    inferred: InferredIndex,
-    variance: Variance,
+    variances: Vec<Variance>,
 }
 
 impl Context<'_> {
-    fn build_constraints_for_item(&mut self) {
-        match self.def {
+    fn solve(mut self) -> Vec<Variance> {
+        tracing::debug!("solve(generics={:?})", self.generics);
+        match self.generics.def() {
             GenericDefId::AdtId(adt) => {
                 let db = self.db;
                 let mut add_constraints_from_variant = |variant| {
@@ -225,6 +204,26 @@ impl Context<'_> {
             }
             _ => {}
         }
+        let mut variances = self.variances;
+
+        // Const parameters are always invariant.
+        // Make all const parameters invariant.
+        for (idx, param) in self.generics.iter_id().enumerate() {
+            if let GenericParamId::ConstParamId(_) = param {
+                variances[idx] = Variance::Invariant;
+            }
+        }
+
+        // Functions are permitted to have unused generic parameters: make those invariant.
+        if let GenericDefId::FunctionId(_) = self.generics.def() {
+            for variance in &mut variances {
+                if *variance == Variance::Bivariant {
+                    *variance = Variance::Invariant;
+                }
+            }
+        }
+
+        variances
     }
 
     fn contravariant(&mut self, variance: Variance) -> Variance {
@@ -353,14 +352,8 @@ impl Context<'_> {
             // Chalk has no params, so use placeholders for now?
             TyKind::Placeholder(index) => {
                 let idx = crate::from_placeholder_idx(self.db, *index);
-                let index = idx.local_id.into_raw().into_u32() as usize + self.len_self_lifetimes;
-                let inferred = if idx.parent == self.def {
-                    InferredIndex(self.has_trait_self as usize + index)
-                } else {
-                    InferredIndex(self.len_self + index)
-                };
-                tracing::debug!("add_constraint(index={:?}, variance={:?})", inferred, variance);
-                self.constraints.push(Constraint { inferred, variance });
+                let inferred = InferredIndex(self.generics.type_or_const_param_idx(idx).unwrap());
+                self.constrain(inferred, variance);
             }
             TyKind::Function(f) => {
                 self.add_constraints_from_sig(f, variance);
@@ -396,7 +389,7 @@ impl Context<'_> {
         if args.is_empty() {
             return;
         }
-        if def_id == self.def {
+        if def_id == self.generics.def() {
             // HACK: Workaround for the trivial cycle salsa case (see
             // recursive_one_bivariant_more_non_bivariant_params test)
             let variance_i = variance.xform(Variance::Bivariant);
@@ -463,18 +456,17 @@ impl Context<'_> {
     /// Adds constraints appropriate for a region appearing in a
     /// context with ambient variance `variance`
     fn add_constraints_from_region(&mut self, region: &Lifetime, variance: Variance) {
+        tracing::debug!(
+            "add_constraints_from_region(region={:?}, variance={:?})",
+            region,
+            variance
+        );
         match region.data(Interner) {
             // FIXME: chalk has no params?
             LifetimeData::Placeholder(index) => {
                 let idx = crate::lt_from_placeholder_idx(self.db, *index);
-                let index = idx.local_id.into_raw().into_u32() as usize;
-                let inferred = if idx.parent == self.def {
-                    InferredIndex(index)
-                } else {
-                    InferredIndex(self.has_trait_self as usize + self.len_self + index)
-                };
-                tracing::debug!("add_constraint(index={:?}, variance={:?})", inferred, variance);
-                self.constraints.push(Constraint { inferred, variance: variance.clone() });
+                let inferred = InferredIndex(self.generics.lifetime_idx(idx).unwrap());
+                self.constrain(inferred, variance);
             }
             LifetimeData::Static => {}
 
@@ -513,50 +505,15 @@ impl Context<'_> {
             }
         }
     }
-}
-
-impl Context<'_> {
-    fn solve(self) -> Vec<Variance> {
-        let mut solutions = vec![Variance::Bivariant; self.generics.len()];
-        // Propagate constraints until a fixed point is reached. Note
-        // that the maximum number of iterations is 2C where C is the
-        // number of constraints (each variable can change values at most
-        // twice). Since number of constraints is linear in size of the
-        // input, so is the inference process.
-        let mut changed = true;
-        while changed {
-            changed = false;
-
-            for constraint in &self.constraints {
-                let &Constraint { inferred, variance } = constraint;
-                let InferredIndex(inferred) = inferred;
-                let old_value = solutions[inferred];
-                let new_value = variance.glb(old_value);
-                if old_value != new_value {
-                    solutions[inferred] = new_value;
-                    changed = true;
-                }
-            }
-        }
 
-        // Const parameters are always invariant.
-        // Make all const parameters invariant.
-        for (idx, param) in self.generics.iter_id().enumerate() {
-            if let GenericParamId::ConstParamId(_) = param {
-                solutions[idx] = Variance::Invariant;
-            }
-        }
-
-        // Functions are permitted to have unused generic parameters: make those invariant.
-        if let GenericDefId::FunctionId(_) = self.def {
-            for variance in &mut solutions {
-                if *variance == Variance::Bivariant {
-                    *variance = Variance::Invariant;
-                }
-            }
-        }
-
-        solutions
+    fn constrain(&mut self, inferred: InferredIndex, variance: Variance) {
+        tracing::debug!(
+            "constrain(index={:?}, variance={:?}, to={:?})",
+            inferred,
+            self.variances[inferred.0],
+            variance
+        );
+        self.variances[inferred.0] = self.variances[inferred.0].glb(variance);
     }
 }