about summary refs log tree commit diff
diff options
context:
space:
mode:
authorroife <roifewu@gmail.com>2024-06-08 03:10:27 +0800
committerroife <roifewu@gmail.com>2024-06-11 21:02:13 +0800
commit80a2ac568108dbe1af674a7ac6f48adf74b0619e (patch)
tree814f6d79dcdc3e7ce480fb1f6bd2785f8a9c0edc
parent4cf0490c213628d25909a91e626b8bada568081d (diff)
downloadrust-80a2ac568108dbe1af674a7ac6f48adf74b0619e.tar.gz
rust-80a2ac568108dbe1af674a7ac6f48adf74b0619e.zip
internal: simplify and refactor write_where_clause
-rw-r--r--src/tools/rust-analyzer/Cargo.lock1
-rw-r--r--src/tools/rust-analyzer/crates/hir/Cargo.toml1
-rw-r--r--src/tools/rust-analyzer/crates/hir/src/display.rs185
3 files changed, 98 insertions, 89 deletions
diff --git a/src/tools/rust-analyzer/Cargo.lock b/src/tools/rust-analyzer/Cargo.lock
index efb313cd3c8..57d43dad3fd 100644
--- a/src/tools/rust-analyzer/Cargo.lock
+++ b/src/tools/rust-analyzer/Cargo.lock
@@ -512,6 +512,7 @@ dependencies = [
  "hir-def",
  "hir-expand",
  "hir-ty",
+ "intern",
  "itertools",
  "once_cell",
  "rustc-hash",
diff --git a/src/tools/rust-analyzer/crates/hir/Cargo.toml b/src/tools/rust-analyzer/crates/hir/Cargo.toml
index 6d7ecd1e50b..edf26a07a74 100644
--- a/src/tools/rust-analyzer/crates/hir/Cargo.toml
+++ b/src/tools/rust-analyzer/crates/hir/Cargo.toml
@@ -27,6 +27,7 @@ cfg.workspace = true
 hir-def.workspace = true
 hir-expand.workspace = true
 hir-ty.workspace = true
+intern.workspace = true
 stdx.workspace = true
 syntax.workspace = true
 tt.workspace = true
diff --git a/src/tools/rust-analyzer/crates/hir/src/display.rs b/src/tools/rust-analyzer/crates/hir/src/display.rs
index c276e87786d..1d8bbf9de6d 100644
--- a/src/tools/rust-analyzer/crates/hir/src/display.rs
+++ b/src/tools/rust-analyzer/crates/hir/src/display.rs
@@ -3,7 +3,8 @@ use either::Either;
 use hir_def::{
     data::adt::{StructKind, VariantData},
     generics::{
-        TypeOrConstParamData, TypeParamProvenance, WherePredicate, WherePredicateTypeTarget,
+        GenericParams, TypeOrConstParamData, TypeParamProvenance, WherePredicate,
+        WherePredicateTypeTarget,
     },
     lang_item::LangItem,
     type_ref::{TypeBound, TypeRef},
@@ -16,6 +17,8 @@ use hir_ty::{
     },
     AliasEq, AliasTy, Interner, ProjectionTyExt, TraitRefExt, TyKind, WhereClause,
 };
+use intern::Interned;
+use itertools::Itertools;
 
 use crate::{
     Adt, AsAssocItem, AssocItem, AssocItemContainer, Const, ConstParam, Enum, ExternCrateDecl,
@@ -30,9 +33,13 @@ impl HirDisplay for Function {
         let data = db.function_data(self.id);
         let container = self.as_assoc_item(db).map(|it| it.container(db));
         let mut module = self.module(db);
-        if let Some(AssocItemContainer::Impl(_)) = container {
-            // Block-local impls are "hoisted" to the nearest (non-block) module.
-            module = module.nearest_non_block_module(db);
+
+        match container {
+            Some(AssocItemContainer::Impl(_)) => {
+                // Block-local impls are "hoisted" to the nearest (non-block) module.
+                module = module.nearest_non_block_module(db);
+            }
+            _ => {}
         }
         let module_id = module.id;
         write_visibility(module_id, self.visibility(db), f)?;
@@ -555,101 +562,96 @@ fn write_where_clause(
 ) -> Result<bool, HirDisplayError> {
     let params = f.db.generic_params(def);
 
-    // unnamed type targets are displayed inline with the argument itself, e.g. `f: impl Y`.
-    let is_unnamed_type_target = |target: &WherePredicateTypeTarget| match target {
-        WherePredicateTypeTarget::TypeRef(_) => false,
-        WherePredicateTypeTarget::TypeOrConstParam(id) => {
-            params.type_or_consts[*id].name().is_none()
-        }
+    let no_displayable_pred = |params: &Interned<GenericParams>| {
+        params.where_predicates.iter().all(|pred| {
+            matches!(
+                pred,
+                WherePredicate::TypeBound { target: WherePredicateTypeTarget::TypeOrConstParam(id), .. }
+                if params.type_or_consts[*id].name().is_none()
+            )
+        })
     };
 
-    let has_displayable_predicate = params
-        .where_predicates
-        .iter()
-        .any(|pred| {
-            !matches!(pred, WherePredicate::TypeBound { target, .. } if is_unnamed_type_target(target))
-        });
-
-    if !has_displayable_predicate {
+    if no_displayable_pred(&params) {
         return Ok(false);
     }
 
+    f.write_str("\nwhere")?;
+    write_where_predicates(&params, f)?;
+
+    Ok(true)
+}
+
+fn write_where_predicates(
+    params: &Interned<GenericParams>,
+    f: &mut HirFormatter<'_>,
+) -> Result<(), HirDisplayError> {
+    use WherePredicate::*;
+
+    // unnamed type targets are displayed inline with the argument itself, e.g. `f: impl Y`.
+    let is_unnamed_type_target =
+        |params: &Interned<GenericParams>, target: &WherePredicateTypeTarget| {
+            matches!(target,
+                WherePredicateTypeTarget::TypeOrConstParam(id) if params.type_or_consts[*id].name().is_none()
+            )
+        };
+
     let write_target = |target: &WherePredicateTypeTarget, f: &mut HirFormatter<'_>| match target {
         WherePredicateTypeTarget::TypeRef(ty) => ty.hir_fmt(f),
-        WherePredicateTypeTarget::TypeOrConstParam(id) => {
-            match &params.type_or_consts[*id].name() {
-                Some(name) => write!(f, "{}", name.display(f.db.upcast())),
-                None => f.write_str("{unnamed}"),
-            }
-        }
+        WherePredicateTypeTarget::TypeOrConstParam(id) => match params.type_or_consts[*id].name() {
+            Some(name) => write!(f, "{}", name.display(f.db.upcast())),
+            None => f.write_str("{unnamed}"),
+        },
     };
 
-    f.write_str("\nwhere")?;
-
-    for (pred_idx, pred) in params.where_predicates.iter().enumerate() {
-        let prev_pred =
-            if pred_idx == 0 { None } else { Some(&params.where_predicates[pred_idx - 1]) };
+    let check_same_target = |pred1: &WherePredicate, pred2: &WherePredicate| match (pred1, pred2) {
+        (TypeBound { target: t1, .. }, TypeBound { target: t2, .. }) => t1 == t2,
+        (Lifetime { target: t1, .. }, Lifetime { target: t2, .. }) => t1 == t2,
+        (
+            ForLifetime { lifetimes: l1, target: t1, .. },
+            ForLifetime { lifetimes: l2, target: t2, .. },
+        ) => l1 == l2 && t1 == t2,
+        _ => false,
+    };
 
-        let new_predicate = |f: &mut HirFormatter<'_>| {
-            f.write_str(if pred_idx == 0 { "\n    " } else { ",\n    " })
-        };
+    let mut iter = params.where_predicates.iter().peekable();
+    while let Some(pred) = iter.next() {
+        if matches!(pred, TypeBound { target, .. } if is_unnamed_type_target(params, target)) {
+            continue;
+        }
 
+        f.write_str("\n    ")?;
         match pred {
-            WherePredicate::TypeBound { target, .. } if is_unnamed_type_target(target) => {}
-            WherePredicate::TypeBound { target, bound } => {
-                if matches!(prev_pred, Some(WherePredicate::TypeBound { target: target_, .. }) if target_ == target)
-                {
-                    f.write_str(" + ")?;
-                } else {
-                    new_predicate(f)?;
-                    write_target(target, f)?;
-                    f.write_str(": ")?;
-                }
+            TypeBound { target, bound } => {
+                write_target(target, f)?;
+                f.write_str(": ")?;
                 bound.hir_fmt(f)?;
             }
-            WherePredicate::Lifetime { target, bound } => {
-                if matches!(prev_pred, Some(WherePredicate::Lifetime { target: target_, .. }) if target_ == target)
-                {
-                    write!(f, " + {}", bound.name.display(f.db.upcast()))?;
-                } else {
-                    new_predicate(f)?;
-                    write!(
-                        f,
-                        "{}: {}",
-                        target.name.display(f.db.upcast()),
-                        bound.name.display(f.db.upcast())
-                    )?;
-                }
+            Lifetime { target, bound } => {
+                let target = target.name.display(f.db.upcast());
+                let bound = bound.name.display(f.db.upcast());
+                write!(f, "{target}: {bound}")?;
             }
-            WherePredicate::ForLifetime { lifetimes, target, bound } => {
-                if matches!(
-                    prev_pred,
-                    Some(WherePredicate::ForLifetime { lifetimes: lifetimes_, target: target_, .. })
-                    if lifetimes_ == lifetimes && target_ == target,
-                ) {
-                    f.write_str(" + ")?;
-                } else {
-                    new_predicate(f)?;
-                    f.write_str("for<")?;
-                    for (idx, lifetime) in lifetimes.iter().enumerate() {
-                        if idx != 0 {
-                            f.write_str(", ")?;
-                        }
-                        write!(f, "{}", lifetime.display(f.db.upcast()))?;
-                    }
-                    f.write_str("> ")?;
-                    write_target(target, f)?;
-                    f.write_str(": ")?;
-                }
+            ForLifetime { lifetimes, target, bound } => {
+                let lifetimes = lifetimes.iter().map(|it| it.display(f.db.upcast())).join(", ");
+                write!(f, "for<{lifetimes}> ")?;
+                write_target(target, f)?;
+                f.write_str(": ")?;
                 bound.hir_fmt(f)?;
             }
         }
-    }
 
-    // End of final predicate. There must be at least one predicate here.
-    f.write_char(',')?;
+        while let Some(nxt) = iter.next_if(|nxt| check_same_target(pred, nxt)) {
+            f.write_str(" + ")?;
+            match nxt {
+                TypeBound { bound, .. } | ForLifetime { bound, .. } => bound.hir_fmt(f)?,
+                Lifetime { bound, .. } => write!(f, "{}", bound.name.display(f.db.upcast()))?,
+            }
+        }
+        f.write_str(",")?;
+    }
 
-    Ok(true)
+    Ok(())
 }
 
 impl HirDisplay for Const {
@@ -689,17 +691,8 @@ impl HirDisplay for Static {
 
 impl HirDisplay for Trait {
     fn hir_fmt(&self, f: &mut HirFormatter<'_>) -> Result<(), HirDisplayError> {
-        write_visibility(self.module(f.db).id, self.visibility(f.db), f)?;
-        let data = f.db.trait_data(self.id);
-        if data.is_unsafe {
-            f.write_str("unsafe ")?;
-        }
-        if data.is_auto {
-            f.write_str("auto ")?;
-        }
-        write!(f, "trait {}", data.name.display(f.db.upcast()))?;
+        write_trait_header(self, f)?;
         let def_id = GenericDefId::TraitId(self.id);
-        write_generic_params(def_id, f)?;
         let has_where_clause = write_where_clause(def_id, f)?;
 
         if let Some(limit) = f.entity_limit {
@@ -735,6 +728,20 @@ impl HirDisplay for Trait {
     }
 }
 
+fn write_trait_header(trait_: &Trait, f: &mut HirFormatter<'_>) -> Result<(), HirDisplayError> {
+    write_visibility(trait_.module(f.db).id, trait_.visibility(f.db), f)?;
+    let data = f.db.trait_data(trait_.id);
+    if data.is_unsafe {
+        f.write_str("unsafe ")?;
+    }
+    if data.is_auto {
+        f.write_str("auto ")?;
+    }
+    write!(f, "trait {}", data.name.display(f.db.upcast()))?;
+    write_generic_params(GenericDefId::TraitId(trait_.id), f)?;
+    Ok(())
+}
+
 impl HirDisplay for TraitAlias {
     fn hir_fmt(&self, f: &mut HirFormatter<'_>) -> Result<(), HirDisplayError> {
         write_visibility(self.module(f.db).id, self.visibility(f.db), f)?;