about summary refs log tree commit diff
diff options
context:
space:
mode:
authordavidsemakula <hello@davidsemakula.com>2024-02-03 13:06:52 +0300
committerdavidsemakula <hello@davidsemakula.com>2024-02-03 15:51:14 +0300
commitc0071ace5a13b93e7570d7b05bcbc9ddc721833c (patch)
tree3af4116d34ccc36675ebd753b35ac2d233300fcf
parent46d79e21b5e2f5d9ce86a22d209453fcc8205372 (diff)
downloadrust-c0071ace5a13b93e7570d7b05bcbc9ddc721833c.tar.gz
rust-c0071ace5a13b93e7570d7b05bcbc9ddc721833c.zip
refactor `hir-ty::diagnostics::decl_check`
-rw-r--r--crates/hir-ty/src/diagnostics/decl_check.rs451
1 files changed, 141 insertions, 310 deletions
diff --git a/crates/hir-ty/src/diagnostics/decl_check.rs b/crates/hir-ty/src/diagnostics/decl_check.rs
index 4909b82f88c..75492ec9fec 100644
--- a/crates/hir-ty/src/diagnostics/decl_check.rs
+++ b/crates/hir-ty/src/diagnostics/decl_check.rs
@@ -17,6 +17,7 @@ use std::fmt;
 
 use hir_def::{
     data::adt::VariantData,
+    db::DefDatabase,
     hir::{Pat, PatId},
     src::HasSource,
     AdtId, AttrDefId, ConstId, EnumId, FunctionId, ItemContainerId, Lookup, ModuleDefId, ModuleId,
@@ -248,45 +249,25 @@ impl<'a> DeclValidator<'a> {
 
         // Check the module name.
         let Some(module_name) = module_id.name(self.db.upcast()) else { return };
-        let module_name_replacement =
+        let Some(module_name_replacement) =
             module_name.as_str().and_then(to_lower_snake_case).map(|new_name| Replacement {
                 current_name: module_name,
                 suggested_text: new_name,
                 expected_case: CaseType::LowerSnakeCase,
-            });
-
-        if let Some(module_name_replacement) = module_name_replacement {
-            let module_data = &module_id.def_map(self.db.upcast())[module_id.local_id];
-            let module_src = module_data.declaration_source(self.db.upcast());
-
-            if let Some(module_src) = module_src {
-                let ast_ptr = match module_src.value.name() {
-                    Some(name) => name,
-                    None => {
-                        never!(
-                            "Replacement ({:?}) was generated for a module without a name: {:?}",
-                            module_name_replacement,
-                            module_src
-                        );
-                        return;
-                    }
-                };
-
-                let diagnostic = IncorrectCase {
-                    file: module_src.file_id,
-                    ident_type: IdentType::Module,
-                    ident: AstPtr::new(&ast_ptr),
-                    expected_case: module_name_replacement.expected_case,
-                    ident_text: module_name_replacement
-                        .current_name
-                        .display(self.db.upcast())
-                        .to_string(),
-                    suggested_text: module_name_replacement.suggested_text,
-                };
-
-                self.sink.push(diagnostic);
-            }
-        }
+            })
+        else {
+            return;
+        };
+        let module_data = &module_id.def_map(self.db.upcast())[module_id.local_id];
+        let Some(module_src) = module_data.declaration_source(self.db.upcast()) else {
+            return;
+        };
+        self.create_incorrect_case_diagnostic_for_ast_node(
+            module_name_replacement,
+            module_src.file_id,
+            &module_src.value,
+            IdentType::Module,
+        );
     }
 
     fn validate_trait(&mut self, trait_id: TraitId) {
@@ -297,37 +278,12 @@ impl<'a> DeclValidator<'a> {
 
         // Check the trait name.
         let data = self.db.trait_data(trait_id);
-        let trait_name = data.name.display(self.db.upcast()).to_string();
-        let trait_name_replacement = to_camel_case(&trait_name).map(|new_name| Replacement {
-            current_name: data.name.clone(),
-            suggested_text: new_name,
-            expected_case: CaseType::UpperCamelCase,
-        });
-
-        if let Some(replacement) = trait_name_replacement {
-            let trait_loc = trait_id.lookup(self.db.upcast());
-            let trait_src = trait_loc.source(self.db.upcast());
-
-            let Some(ast_ptr) = trait_src.value.name() else {
-                never!(
-                    "Replacement ({:?}) was generated for a trait without a name: {:?}",
-                    replacement,
-                    trait_src
-                );
-                return;
-            };
-
-            let diagnostic = IncorrectCase {
-                file: trait_src.file_id,
-                ident_type: IdentType::Trait,
-                ident: AstPtr::new(&ast_ptr),
-                expected_case: replacement.expected_case,
-                ident_text: trait_name,
-                suggested_text: replacement.suggested_text,
-            };
-
-            self.sink.push(diagnostic);
-        }
+        self.create_incorrect_case_diagnostic_for_item_name(
+            trait_id,
+            &data.name,
+            CaseType::UpperCamelCase,
+            IdentType::Trait,
+        );
     }
 
     fn validate_func(&mut self, func: FunctionId) {
@@ -348,18 +304,12 @@ impl<'a> DeclValidator<'a> {
         // Check the function name.
         if !is_trait_impl_assoc_fn {
             let data = self.db.function_data(func);
-            let function_name = data.name.display(self.db.upcast()).to_string();
-            let fn_name_replacement =
-                to_lower_snake_case(&function_name).map(|new_name| Replacement {
-                    current_name: data.name.clone(),
-                    suggested_text: new_name,
-                    expected_case: CaseType::LowerSnakeCase,
-                });
-            // If there is at least one element to spawn a warning on,
-            // go to the source map and generate a warning.
-            if let Some(fn_name_replacement) = fn_name_replacement {
-                self.create_incorrect_case_diagnostic_for_func(func, fn_name_replacement);
-            }
+            self.create_incorrect_case_diagnostic_for_item_name(
+                func,
+                &data.name,
+                CaseType::LowerSnakeCase,
+                IdentType::Function,
+            );
         } else {
             cov_mark::hit!(trait_impl_assoc_func_name_incorrect_case_ignored);
         }
@@ -399,47 +349,12 @@ impl<'a> DeclValidator<'a> {
                 ))
             })
             .collect();
-        self.create_incorrect_case_diagnostic_for_variables(func, pats_replacements);
-    }
-
-    /// Given the information about incorrect names in the function declaration, looks up into the source code
-    /// for exact locations and adds diagnostics into the sink.
-    fn create_incorrect_case_diagnostic_for_func(
-        &mut self,
-        func: FunctionId,
-        fn_name_replacement: Replacement,
-    ) {
-        let fn_loc = func.lookup(self.db.upcast());
-        let fn_src = fn_loc.source(self.db.upcast());
-
-        // Diagnostic for function name.
-        let ast_ptr = match fn_src.value.name() {
-            Some(name) => name,
-            None => {
-                never!(
-                    "Replacement ({:?}) was generated for a function without a name: {:?}",
-                    fn_name_replacement,
-                    fn_src
-                );
-                return;
-            }
-        };
-
-        let diagnostic = IncorrectCase {
-            file: fn_src.file_id,
-            ident_type: IdentType::Function,
-            ident: AstPtr::new(&ast_ptr),
-            expected_case: fn_name_replacement.expected_case,
-            ident_text: fn_name_replacement.current_name.display(self.db.upcast()).to_string(),
-            suggested_text: fn_name_replacement.suggested_text,
-        };
-
-        self.sink.push(diagnostic);
+        self.create_incorrect_case_diagnostic_for_func_variables(func, pats_replacements);
     }
 
     /// Given the information about incorrect variable names, looks up into the source code
     /// for exact locations and adds diagnostics into the sink.
-    fn create_incorrect_case_diagnostic_for_variables(
+    fn create_incorrect_case_diagnostic_for_func_variables(
         &mut self,
         func: FunctionId,
         pats_replacements: Vec<(PatId, Replacement)>,
@@ -460,10 +375,6 @@ impl<'a> DeclValidator<'a> {
                         Some(parent) => parent,
                         None => continue,
                     };
-                    let name_ast = match ident_pat.name() {
-                        Some(name_ast) => name_ast,
-                        None => continue,
-                    };
 
                     let is_param = ast::Param::can_cast(parent.kind());
 
@@ -481,16 +392,12 @@ impl<'a> DeclValidator<'a> {
                     let ident_type =
                         if is_param { IdentType::Parameter } else { IdentType::Variable };
 
-                    let diagnostic = IncorrectCase {
-                        file: source_ptr.file_id,
+                    self.create_incorrect_case_diagnostic_for_ast_node(
+                        replacement,
+                        source_ptr.file_id,
+                        &ident_pat,
                         ident_type,
-                        ident: AstPtr::new(&name_ast),
-                        expected_case: replacement.expected_case,
-                        ident_text: replacement.current_name.display(self.db.upcast()).to_string(),
-                        suggested_text: replacement.suggested_text,
-                    };
-
-                    self.sink.push(diagnostic);
+                    );
                 }
             }
         }
@@ -504,20 +411,17 @@ impl<'a> DeclValidator<'a> {
         let non_snake_case_allowed = self.allowed(struct_id.into(), allow::NON_SNAKE_CASE, false);
 
         // Check the structure name.
-        let struct_name = data.name.display(self.db.upcast()).to_string();
-        let struct_name_replacement = if !non_camel_case_allowed {
-            to_camel_case(&struct_name).map(|new_name| Replacement {
-                current_name: data.name.clone(),
-                suggested_text: new_name,
-                expected_case: CaseType::UpperCamelCase,
-            })
-        } else {
-            None
-        };
+        if !non_camel_case_allowed {
+            self.create_incorrect_case_diagnostic_for_item_name(
+                struct_id,
+                &data.name,
+                CaseType::UpperCamelCase,
+                IdentType::Structure,
+            );
+        }
 
         // Check the field names.
         let mut struct_fields_replacements = Vec::new();
-
         if !non_snake_case_allowed {
             if let VariantData::Record(fields) = data.variant_data.as_ref() {
                 for (_, field) in fields.iter() {
@@ -535,54 +439,27 @@ impl<'a> DeclValidator<'a> {
         }
 
         // If there is at least one element to spawn a warning on, go to the source map and generate a warning.
-        self.create_incorrect_case_diagnostic_for_struct(
+        self.create_incorrect_case_diagnostic_for_struct_fields(
             struct_id,
-            struct_name_replacement,
             struct_fields_replacements,
         );
     }
 
-    /// Given the information about incorrect names in the struct declaration, looks up into the source code
-    /// for exact locations and adds diagnostics into the sink.
-    fn create_incorrect_case_diagnostic_for_struct(
+    /// Given the information about incorrect names for struct fields,
+    /// looks up into the source code for exact locations and adds diagnostics into the sink.
+    fn create_incorrect_case_diagnostic_for_struct_fields(
         &mut self,
         struct_id: StructId,
-        struct_name_replacement: Option<Replacement>,
         struct_fields_replacements: Vec<Replacement>,
     ) {
         // XXX: Only look at sources if we do have incorrect names.
-        if struct_name_replacement.is_none() && struct_fields_replacements.is_empty() {
+        if struct_fields_replacements.is_empty() {
             return;
         }
 
         let struct_loc = struct_id.lookup(self.db.upcast());
         let struct_src = struct_loc.source(self.db.upcast());
 
-        if let Some(replacement) = struct_name_replacement {
-            let ast_ptr = match struct_src.value.name() {
-                Some(name) => name,
-                None => {
-                    never!(
-                        "Replacement ({:?}) was generated for a structure without a name: {:?}",
-                        replacement,
-                        struct_src
-                    );
-                    return;
-                }
-            };
-
-            let diagnostic = IncorrectCase {
-                file: struct_src.file_id,
-                ident_type: IdentType::Structure,
-                ident: AstPtr::new(&ast_ptr),
-                expected_case: replacement.expected_case,
-                ident_text: replacement.current_name.display(self.db.upcast()).to_string(),
-                suggested_text: replacement.suggested_text,
-            };
-
-            self.sink.push(diagnostic);
-        }
-
         let struct_fields_list = match struct_src.value.field_list() {
             Some(ast::FieldList::RecordFieldList(fields)) => fields,
             _ => {
@@ -638,15 +515,15 @@ impl<'a> DeclValidator<'a> {
         }
 
         // Check the enum name.
-        let enum_name = data.name.display(self.db.upcast()).to_string();
-        let enum_name_replacement = to_camel_case(&enum_name).map(|new_name| Replacement {
-            current_name: data.name.clone(),
-            suggested_text: new_name,
-            expected_case: CaseType::UpperCamelCase,
-        });
+        self.create_incorrect_case_diagnostic_for_item_name(
+            enum_id,
+            &data.name,
+            CaseType::UpperCamelCase,
+            IdentType::Enum,
+        );
 
         // Check the field names.
-        let enum_fields_replacements = data
+        let enum_variants_replacements = data
             .variants
             .iter()
             .filter_map(|(_, name)| {
@@ -659,54 +536,24 @@ impl<'a> DeclValidator<'a> {
             .collect();
 
         // If there is at least one element to spawn a warning on, go to the source map and generate a warning.
-        self.create_incorrect_case_diagnostic_for_enum(
-            enum_id,
-            enum_name_replacement,
-            enum_fields_replacements,
-        )
+        self.create_incorrect_case_diagnostic_for_enum_variants(enum_id, enum_variants_replacements)
     }
 
-    /// Given the information about incorrect names in the struct declaration, looks up into the source code
-    /// for exact locations and adds diagnostics into the sink.
-    fn create_incorrect_case_diagnostic_for_enum(
+    /// Given the information about incorrect names for enum variants,
+    /// looks up into the source code for exact locations and adds diagnostics into the sink.
+    fn create_incorrect_case_diagnostic_for_enum_variants(
         &mut self,
         enum_id: EnumId,
-        enum_name_replacement: Option<Replacement>,
         enum_variants_replacements: Vec<Replacement>,
     ) {
         // XXX: only look at sources if we do have incorrect names
-        if enum_name_replacement.is_none() && enum_variants_replacements.is_empty() {
+        if enum_variants_replacements.is_empty() {
             return;
         }
 
         let enum_loc = enum_id.lookup(self.db.upcast());
         let enum_src = enum_loc.source(self.db.upcast());
 
-        if let Some(replacement) = enum_name_replacement {
-            let ast_ptr = match enum_src.value.name() {
-                Some(name) => name,
-                None => {
-                    never!(
-                        "Replacement ({:?}) was generated for a enum without a name: {:?}",
-                        replacement,
-                        enum_src
-                    );
-                    return;
-                }
-            };
-
-            let diagnostic = IncorrectCase {
-                file: enum_src.file_id,
-                ident_type: IdentType::Enum,
-                ident: AstPtr::new(&ast_ptr),
-                expected_case: replacement.expected_case,
-                ident_text: replacement.current_name.display(self.db.upcast()).to_string(),
-                suggested_text: replacement.suggested_text,
-            };
-
-            self.sink.push(diagnostic);
-        }
-
         let enum_variants_list = match enum_src.value.variant_list() {
             Some(variants) => variants,
             _ => {
@@ -760,47 +607,20 @@ impl<'a> DeclValidator<'a> {
             return;
         }
 
-        let data = self.db.const_data(const_id);
-
         if self.allowed(const_id.into(), allow::NON_UPPER_CASE_GLOBAL, false) {
             return;
         }
 
-        let name = match &data.name {
-            Some(name) => name,
-            None => return,
-        };
-
-        let const_name = name.to_smol_str();
-        let replacement = if let Some(new_name) = to_upper_snake_case(&const_name) {
-            Replacement {
-                current_name: name.clone(),
-                suggested_text: new_name,
-                expected_case: CaseType::UpperSnakeCase,
-            }
-        } else {
-            // Nothing to do here.
+        let data = self.db.const_data(const_id);
+        let Some(name) = &data.name else {
             return;
         };
-
-        let const_loc = const_id.lookup(self.db.upcast());
-        let const_src = const_loc.source(self.db.upcast());
-
-        let ast_ptr = match const_src.value.name() {
-            Some(name) => name,
-            None => return,
-        };
-
-        let diagnostic = IncorrectCase {
-            file: const_src.file_id,
-            ident_type: IdentType::Constant,
-            ident: AstPtr::new(&ast_ptr),
-            expected_case: replacement.expected_case,
-            ident_text: replacement.current_name.display(self.db.upcast()).to_string(),
-            suggested_text: replacement.suggested_text,
-        };
-
-        self.sink.push(diagnostic);
+        self.create_incorrect_case_diagnostic_for_item_name(
+            const_id,
+            name,
+            CaseType::UpperSnakeCase,
+            IdentType::Constant,
+        );
     }
 
     fn validate_static(&mut self, static_id: StaticId) {
@@ -814,38 +634,12 @@ impl<'a> DeclValidator<'a> {
             return;
         }
 
-        let name = &data.name;
-
-        let static_name = name.to_smol_str();
-        let replacement = if let Some(new_name) = to_upper_snake_case(&static_name) {
-            Replacement {
-                current_name: name.clone(),
-                suggested_text: new_name,
-                expected_case: CaseType::UpperSnakeCase,
-            }
-        } else {
-            // Nothing to do here.
-            return;
-        };
-
-        let static_loc = static_id.lookup(self.db.upcast());
-        let static_src = static_loc.source(self.db.upcast());
-
-        let ast_ptr = match static_src.value.name() {
-            Some(name) => name,
-            None => return,
-        };
-
-        let diagnostic = IncorrectCase {
-            file: static_src.file_id,
-            ident_type: IdentType::StaticVariable,
-            ident: AstPtr::new(&ast_ptr),
-            expected_case: replacement.expected_case,
-            ident_text: replacement.current_name.display(self.db.upcast()).to_string(),
-            suggested_text: replacement.suggested_text,
-        };
-
-        self.sink.push(diagnostic);
+        self.create_incorrect_case_diagnostic_for_item_name(
+            static_id,
+            &data.name,
+            CaseType::UpperSnakeCase,
+            IdentType::StaticVariable,
+        );
     }
 
     fn validate_type_alias(&mut self, type_alias_id: TypeAliasId) {
@@ -862,38 +656,75 @@ impl<'a> DeclValidator<'a> {
 
         // Check the type alias name.
         let data = self.db.type_alias_data(type_alias_id);
-        let type_alias_name = data.name.display(self.db.upcast()).to_string();
-        let type_alias_name_replacement =
-            to_camel_case(&type_alias_name).map(|new_name| Replacement {
-                current_name: data.name.clone(),
-                suggested_text: new_name,
-                expected_case: CaseType::UpperCamelCase,
-            });
-
-        if let Some(replacement) = type_alias_name_replacement {
-            let type_alias_loc = type_alias_id.lookup(self.db.upcast());
-            let type_alias_src = type_alias_loc.source(self.db.upcast());
-
-            let Some(ast_ptr) = type_alias_src.value.name() else {
-                never!(
-                    "Replacement ({:?}) was generated for a type alias without a name: {:?}",
-                    replacement,
-                    type_alias_src
-                );
-                return;
-            };
+        self.create_incorrect_case_diagnostic_for_item_name(
+            type_alias_id,
+            &data.name,
+            CaseType::UpperCamelCase,
+            IdentType::TypeAlias,
+        );
+    }
 
-            let diagnostic = IncorrectCase {
-                file: type_alias_src.file_id,
-                ident_type: IdentType::TypeAlias,
-                ident: AstPtr::new(&ast_ptr),
-                expected_case: replacement.expected_case,
-                ident_text: type_alias_name,
-                suggested_text: replacement.suggested_text,
-            };
+    fn create_incorrect_case_diagnostic_for_item_name<N, S, L>(
+        &mut self,
+        item_id: L,
+        name: &Name,
+        expected_case: CaseType,
+        ident_type: IdentType,
+    ) where
+        N: AstNode + HasName + fmt::Debug,
+        S: HasSource<Value = N>,
+        L: Lookup<Data = S, Database<'a> = dyn DefDatabase + 'a>,
+    {
+        let to_expected_case_type = match expected_case {
+            CaseType::LowerSnakeCase => to_lower_snake_case,
+            CaseType::UpperSnakeCase => to_upper_snake_case,
+            CaseType::UpperCamelCase => to_camel_case,
+        };
+        let Some(replacement) = to_expected_case_type(&name.to_smol_str()).map(|new_name| {
+            Replacement { current_name: name.clone(), suggested_text: new_name, expected_case }
+        }) else {
+            return;
+        };
 
-            self.sink.push(diagnostic);
-        }
+        let item_loc = item_id.lookup(self.db.upcast());
+        let item_src = item_loc.source(self.db.upcast());
+        self.create_incorrect_case_diagnostic_for_ast_node(
+            replacement,
+            item_src.file_id,
+            &item_src.value,
+            ident_type,
+        );
+    }
+
+    fn create_incorrect_case_diagnostic_for_ast_node<T>(
+        &mut self,
+        replacement: Replacement,
+        file_id: HirFileId,
+        node: &T,
+        ident_type: IdentType,
+    ) where
+        T: AstNode + HasName + fmt::Debug,
+    {
+        let Some(name_ast) = node.name() else {
+            never!(
+                "Replacement ({:?}) was generated for a {:?} without a name: {:?}",
+                replacement,
+                ident_type,
+                node
+            );
+            return;
+        };
+
+        let diagnostic = IncorrectCase {
+            file: file_id,
+            ident_type,
+            ident: AstPtr::new(&name_ast),
+            expected_case: replacement.expected_case,
+            ident_text: replacement.current_name.display(self.db.upcast()).to_string(),
+            suggested_text: replacement.suggested_text,
+        };
+
+        self.sink.push(diagnostic);
     }
 
     fn is_trait_impl_container(&self, container_id: ItemContainerId) -> bool {