about summary refs log tree commit diff
diff options
context:
space:
mode:
authorChayim Refael Friedman <chayimfr@gmail.com>2025-04-14 00:55:27 +0300
committerChayim Refael Friedman <chayimfr@gmail.com>2025-04-22 15:19:35 +0300
commitf97d90f89eba2f59d29a8d2af555e638d9ddad36 (patch)
tree7508df334fe404b9ac52d09e4fead2597e8472fe
parentf0fa09621323e427ff9f7de856bd020bfa9e28f2 (diff)
downloadrust-f97d90f89eba2f59d29a8d2af555e638d9ddad36.tar.gz
rust-f97d90f89eba2f59d29a8d2af555e638d9ddad36.zip
Adjust for `salsa::Id::from_u32()` being unsafe
This impacts our manual `salsa::Id` wrappers. I refactored them a bit to improve safety.
-rw-r--r--src/tools/rust-analyzer/crates/hir-expand/src/lib.rs2
-rw-r--r--src/tools/rust-analyzer/crates/hir-ty/src/mapping.rs6
-rw-r--r--src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/msg/flat.rs4
-rw-r--r--src/tools/rust-analyzer/crates/span/src/hygiene.rs125
-rw-r--r--src/tools/rust-analyzer/crates/span/src/lib.rs126
5 files changed, 80 insertions, 183 deletions
diff --git a/src/tools/rust-analyzer/crates/hir-expand/src/lib.rs b/src/tools/rust-analyzer/crates/hir-expand/src/lib.rs
index cd2448bad4a..f0a9a2ad52c 100644
--- a/src/tools/rust-analyzer/crates/hir-expand/src/lib.rs
+++ b/src/tools/rust-analyzer/crates/hir-expand/src/lib.rs
@@ -1051,7 +1051,7 @@ impl ExpandTo {
 
 intern::impl_internable!(ModPath, attrs::AttrInput);
 
-#[salsa::interned(no_lifetime)]
+#[salsa::interned(no_lifetime, debug)]
 #[doc(alias = "MacroFileId")]
 pub struct MacroCallId {
     pub loc: MacroCallLoc,
diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/mapping.rs b/src/tools/rust-analyzer/crates/hir-ty/src/mapping.rs
index f7511e5f63a..2abc1ac62a9 100644
--- a/src/tools/rust-analyzer/crates/hir-ty/src/mapping.rs
+++ b/src/tools/rust-analyzer/crates/hir-ty/src/mapping.rs
@@ -136,7 +136,8 @@ pub fn from_assoc_type_id(id: AssocTypeId) -> TypeAliasId {
 
 pub fn from_placeholder_idx(db: &dyn HirDatabase, idx: PlaceholderIndex) -> TypeOrConstParamId {
     assert_eq!(idx.ui, chalk_ir::UniverseIndex::ROOT);
-    let interned_id = FromId::from_id(Id::from_u32(idx.idx.try_into().unwrap()));
+    // SAFETY: We cannot really encapsulate this unfortunately, so just hope this is sound.
+    let interned_id = FromId::from_id(unsafe { Id::from_u32(idx.idx.try_into().unwrap()) });
     db.lookup_intern_type_or_const_param_id(interned_id)
 }
 
@@ -150,7 +151,8 @@ pub fn to_placeholder_idx(db: &dyn HirDatabase, id: TypeOrConstParamId) -> Place
 
 pub fn lt_from_placeholder_idx(db: &dyn HirDatabase, idx: PlaceholderIndex) -> LifetimeParamId {
     assert_eq!(idx.ui, chalk_ir::UniverseIndex::ROOT);
-    let interned_id = FromId::from_id(Id::from_u32(idx.idx.try_into().unwrap()));
+    // SAFETY: We cannot really encapsulate this unfortunately, so just hope this is sound.
+    let interned_id = FromId::from_id(unsafe { Id::from_u32(idx.idx.try_into().unwrap()) });
     db.lookup_intern_lifetime_param_id(interned_id)
 }
 
diff --git a/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/msg/flat.rs b/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/msg/flat.rs
index 101c4b3105a..597ffa05d20 100644
--- a/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/msg/flat.rs
+++ b/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/msg/flat.rs
@@ -72,7 +72,9 @@ pub fn deserialize_span_data_index_map(map: &[u32]) -> SpanDataIndexMap {
                     ast_id: ErasedFileAstId::from_raw(ast_id),
                 },
                 range: TextRange::new(start.into(), end.into()),
-                ctx: SyntaxContext::from_u32(e),
+                // SAFETY: We only receive spans from the server. If someone mess up the communication UB can happen,
+                // but that will be their problem.
+                ctx: unsafe { SyntaxContext::from_u32(e) },
             }
         })
         .collect()
diff --git a/src/tools/rust-analyzer/crates/span/src/hygiene.rs b/src/tools/rust-analyzer/crates/span/src/hygiene.rs
index a2923cd2233..6022b5b1209 100644
--- a/src/tools/rust-analyzer/crates/span/src/hygiene.rs
+++ b/src/tools/rust-analyzer/crates/span/src/hygiene.rs
@@ -27,7 +27,10 @@ use crate::Edition;
 #[cfg(feature = "salsa")]
 #[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
 pub struct SyntaxContext(
-    salsa::Id,
+    /// # Invariant
+    ///
+    /// This is either a valid `salsa::Id` or a root `SyntaxContext`.
+    u32,
     std::marker::PhantomData<&'static salsa::plumbing::interned::Value<SyntaxContext>>,
 );
 
@@ -95,10 +98,11 @@ const _: () = {
         type Fields<'a> = SyntaxContextData;
         type Struct<'a> = SyntaxContext;
         fn struct_from_id<'db>(id: salsa::Id) -> Self::Struct<'db> {
-            SyntaxContext(id, std::marker::PhantomData)
+            SyntaxContext::from_salsa_id(id)
         }
         fn deref_struct(s: Self::Struct<'_>) -> salsa::Id {
-            s.0
+            s.as_salsa_id()
+                .expect("`SyntaxContext::deref_structs()` called on a root `SyntaxContext`")
         }
     }
     impl SyntaxContext {
@@ -115,12 +119,12 @@ const _: () = {
     }
     impl zalsa_::AsId for SyntaxContext {
         fn as_id(&self) -> salsa::Id {
-            self.0
+            self.as_salsa_id().expect("`SyntaxContext::as_id()` called on a root `SyntaxContext`")
         }
     }
     impl zalsa_::FromId for SyntaxContext {
         fn from_id(id: salsa::Id) -> Self {
-            Self(id, std::marker::PhantomData)
+            Self::from_salsa_id(id)
         }
     }
     unsafe impl Send for SyntaxContext {}
@@ -210,44 +214,44 @@ const _: () = {
         where
             Db: ?Sized + zalsa_::Database,
         {
-            if self.is_root() {
-                return None;
-            }
-            let fields = SyntaxContext::ingredient(db).fields(db.as_dyn_database(), self);
-            std::clone::Clone::clone(&fields.outer_expn)
+            let id = self.as_salsa_id()?;
+            let fields = SyntaxContext::ingredient(db).data(db.as_dyn_database(), id);
+            fields.outer_expn
         }
 
         pub fn outer_transparency<Db>(self, db: &'db Db) -> Transparency
         where
             Db: ?Sized + zalsa_::Database,
         {
-            if self.is_root() {
-                return Transparency::Opaque;
-            }
-            let fields = SyntaxContext::ingredient(db).fields(db.as_dyn_database(), self);
-            std::clone::Clone::clone(&fields.outer_transparency)
+            let Some(id) = self.as_salsa_id() else { return Transparency::Opaque };
+            let fields = SyntaxContext::ingredient(db).data(db.as_dyn_database(), id);
+            fields.outer_transparency
         }
 
         pub fn edition<Db>(self, db: &'db Db) -> Edition
         where
             Db: ?Sized + zalsa_::Database,
         {
-            if self.is_root() {
-                return Edition::from_u32(SyntaxContext::MAX_ID - self.0.as_u32());
+            match self.as_salsa_id() {
+                Some(id) => {
+                    let fields = SyntaxContext::ingredient(db).data(db.as_dyn_database(), id);
+                    fields.edition
+                }
+                None => Edition::from_u32(SyntaxContext::MAX_ID - self.into_u32()),
             }
-            let fields = SyntaxContext::ingredient(db).fields(db.as_dyn_database(), self);
-            std::clone::Clone::clone(&fields.edition)
         }
 
         pub fn parent<Db>(self, db: &'db Db) -> SyntaxContext
         where
             Db: ?Sized + zalsa_::Database,
         {
-            if self.is_root() {
-                return self;
+            match self.as_salsa_id() {
+                Some(id) => {
+                    let fields = SyntaxContext::ingredient(db).data(db.as_dyn_database(), id);
+                    fields.parent
+                }
+                None => self,
             }
-            let fields = SyntaxContext::ingredient(db).fields(db.as_dyn_database(), self);
-            std::clone::Clone::clone(&fields.parent)
         }
 
         /// This context, but with all transparent and semi-transparent expansions filtered away.
@@ -255,11 +259,13 @@ const _: () = {
         where
             Db: ?Sized + zalsa_::Database,
         {
-            if self.is_root() {
-                return self;
+            match self.as_salsa_id() {
+                Some(id) => {
+                    let fields = SyntaxContext::ingredient(db).data(db.as_dyn_database(), id);
+                    fields.opaque
+                }
+                None => self,
             }
-            let fields = SyntaxContext::ingredient(db).fields(db.as_dyn_database(), self);
-            std::clone::Clone::clone(&fields.opaque)
         }
 
         /// This context, but with all transparent expansions filtered away.
@@ -267,33 +273,19 @@ const _: () = {
         where
             Db: ?Sized + zalsa_::Database,
         {
-            if self.is_root() {
-                return self;
+            match self.as_salsa_id() {
+                Some(id) => {
+                    let fields = SyntaxContext::ingredient(db).data(db.as_dyn_database(), id);
+                    fields.opaque_and_semitransparent
+                }
+                None => self,
             }
-            let fields = SyntaxContext::ingredient(db).fields(db.as_dyn_database(), self);
-            std::clone::Clone::clone(&fields.opaque_and_semitransparent)
-        }
-
-        pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-            salsa::with_attached_database(|db| {
-                let fields = SyntaxContext::ingredient(db).fields(db.as_dyn_database(), this);
-                let mut f = f.debug_struct("SyntaxContextData");
-                let f = f.field("outer_expn", &fields.outer_expn);
-                let f = f.field("outer_transparency", &fields.outer_expn);
-                let f = f.field("edition", &fields.edition);
-                let f = f.field("parent", &fields.parent);
-                let f = f.field("opaque", &fields.opaque);
-                let f = f.field("opaque_and_semitransparent", &fields.opaque_and_semitransparent);
-                f.finish()
-            })
-            .unwrap_or_else(|| {
-                f.debug_tuple("SyntaxContextData").field(&zalsa_::AsId::as_id(&this)).finish()
-            })
         }
     }
 };
 
 impl SyntaxContext {
+    #[inline]
     pub fn is_root(self) -> bool {
         (SyntaxContext::MAX_ID - Edition::LATEST as u32) <= self.into_u32()
             && self.into_u32() <= (SyntaxContext::MAX_ID - Edition::Edition2015 as u32)
@@ -307,9 +299,11 @@ impl SyntaxContext {
     }
 
     /// The root context, which is the parent of all other contexts. All [`FileId`]s have this context.
+    #[inline]
     pub const fn root(edition: Edition) -> Self {
         let edition = edition as u32;
-        SyntaxContext::from_u32(SyntaxContext::MAX_ID - edition)
+        // SAFETY: Roots are valid `SyntaxContext`s
+        unsafe { SyntaxContext::from_u32(SyntaxContext::MAX_ID - edition) }
     }
 }
 
@@ -317,12 +311,34 @@ impl SyntaxContext {
 impl SyntaxContext {
     const MAX_ID: u32 = salsa::Id::MAX_U32 - 1;
 
+    #[inline]
     pub const fn into_u32(self) -> u32 {
-        self.0.as_u32()
+        self.0
     }
 
-    pub const fn from_u32(u32: u32) -> Self {
-        Self(salsa::Id::from_u32(u32), std::marker::PhantomData)
+    /// # Safety
+    ///
+    /// The ID must be a valid `SyntaxContext`.
+    #[inline]
+    pub const unsafe fn from_u32(u32: u32) -> Self {
+        // INVARIANT: Our precondition.
+        Self(u32, std::marker::PhantomData)
+    }
+
+    #[inline]
+    fn as_salsa_id(self) -> Option<salsa::Id> {
+        if self.is_root() {
+            None
+        } else {
+            // SAFETY: By our invariant, this is either a root (which we verified it's not) or a valid `salsa::Id`.
+            unsafe { Some(salsa::Id::from_u32(self.0)) }
+        }
+    }
+
+    #[inline]
+    fn from_salsa_id(id: salsa::Id) -> Self {
+        // SAFETY: This comes from a Salsa ID.
+        unsafe { Self::from_u32(id.as_u32()) }
     }
 }
 #[cfg(not(feature = "salsa"))]
@@ -342,7 +358,10 @@ impl SyntaxContext {
         self.0
     }
 
-    pub const fn from_u32(u32: u32) -> Self {
+    /// # Safety
+    ///
+    /// None. This is always safe to call without the `salsa` feature.
+    pub const unsafe fn from_u32(u32: u32) -> Self {
         Self(u32)
     }
 }
diff --git a/src/tools/rust-analyzer/crates/span/src/lib.rs b/src/tools/rust-analyzer/crates/span/src/lib.rs
index 67f49928f88..54f90908f36 100644
--- a/src/tools/rust-analyzer/crates/span/src/lib.rs
+++ b/src/tools/rust-analyzer/crates/span/src/lib.rs
@@ -184,16 +184,6 @@ impl EditionedFileId {
 mod salsa {
     #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
     pub struct Id(u32);
-
-    impl Id {
-        pub(crate) const fn from_u32(u32: u32) -> Self {
-            Self(u32)
-        }
-
-        pub(crate) const fn as_u32(self) -> u32 {
-            self.0
-        }
-    }
 }
 
 /// Input to the analyzer is a set of files, where each file is identified by
@@ -216,127 +206,11 @@ mod salsa {
 #[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
 pub struct HirFileId(pub salsa::Id);
 
-impl From<MacroCallId> for HirFileId {
-    fn from(value: MacroCallId) -> Self {
-        value.as_file()
-    }
-}
-
-impl fmt::Debug for HirFileId {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        self.repr().fmt(f)
-    }
-}
-
-impl PartialEq<FileId> for HirFileId {
-    fn eq(&self, &other: &FileId) -> bool {
-        self.file_id().map(EditionedFileId::file_id) == Some(other)
-    }
-}
-impl PartialEq<HirFileId> for FileId {
-    fn eq(&self, other: &HirFileId) -> bool {
-        other.file_id().map(EditionedFileId::file_id) == Some(*self)
-    }
-}
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
-pub struct MacroFileId {
-    pub macro_call_id: MacroCallId,
-}
-
 /// `MacroCallId` identifies a particular macro invocation, like
 /// `println!("Hello, {}", world)`.
 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
 pub struct MacroCallId(pub salsa::Id);
 
-impl MacroCallId {
-    pub const MAX_ID: u32 = 0x7fff_ffff;
-
-    pub fn as_file(self) -> HirFileId {
-        MacroFileId { macro_call_id: self }.into()
-    }
-
-    pub fn as_macro_file(self) -> MacroFileId {
-        MacroFileId { macro_call_id: self }
-    }
-}
-
-#[derive(Clone, Copy, PartialEq, Eq, Hash)]
-pub enum HirFileIdRepr {
-    FileId(EditionedFileId),
-    MacroFile(MacroFileId),
-}
-
-impl fmt::Debug for HirFileIdRepr {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match self {
-            Self::FileId(arg0) => arg0.fmt(f),
-            Self::MacroFile(arg0) => {
-                f.debug_tuple("MacroFile").field(&arg0.macro_call_id.0).finish()
-            }
-        }
-    }
-}
-
-impl From<EditionedFileId> for HirFileId {
-    #[allow(clippy::let_unit_value)]
-    fn from(id: EditionedFileId) -> Self {
-        assert!(id.as_u32() <= Self::MAX_HIR_FILE_ID, "FileId index {} is too large", id.as_u32());
-        HirFileId(salsa::Id::from_u32(id.0))
-    }
-}
-
-impl From<MacroFileId> for HirFileId {
-    #[allow(clippy::let_unit_value)]
-    fn from(MacroFileId { macro_call_id: MacroCallId(id) }: MacroFileId) -> Self {
-        let id: u32 = id.as_u32();
-        assert!(id <= Self::MAX_HIR_FILE_ID, "MacroCallId index {id} is too large");
-        HirFileId(salsa::Id::from_u32(id | Self::MACRO_FILE_TAG_MASK))
-    }
-}
-
-impl HirFileId {
-    const MAX_HIR_FILE_ID: u32 = u32::MAX ^ Self::MACRO_FILE_TAG_MASK;
-    const MACRO_FILE_TAG_MASK: u32 = 1 << 31;
-
-    #[inline]
-    pub fn is_macro(self) -> bool {
-        self.0.as_u32() & Self::MACRO_FILE_TAG_MASK != 0
-    }
-
-    #[inline]
-    pub fn macro_file(self) -> Option<MacroFileId> {
-        match self.0.as_u32() & Self::MACRO_FILE_TAG_MASK {
-            0 => None,
-            _ => Some(MacroFileId {
-                macro_call_id: MacroCallId(salsa::Id::from_u32(
-                    self.0.as_u32() ^ Self::MACRO_FILE_TAG_MASK,
-                )),
-            }),
-        }
-    }
-
-    #[inline]
-    pub fn file_id(self) -> Option<EditionedFileId> {
-        match self.0.as_u32() & Self::MACRO_FILE_TAG_MASK {
-            0 => Some(EditionedFileId(self.0.as_u32())),
-            _ => None,
-        }
-    }
-
-    #[inline]
-    pub fn repr(self) -> HirFileIdRepr {
-        match self.0.as_u32() & Self::MACRO_FILE_TAG_MASK {
-            0 => HirFileIdRepr::FileId(EditionedFileId(self.0.as_u32())),
-            _ => HirFileIdRepr::MacroFile(MacroFileId {
-                macro_call_id: MacroCallId(salsa::Id::from_u32(
-                    self.0.as_u32() ^ Self::MACRO_FILE_TAG_MASK,
-                )),
-            }),
-        }
-    }
-}
-
 /// Legacy span type, only defined here as it is still used by the proc-macro server.
 /// While rust-analyzer doesn't use this anymore at all, RustRover relies on the legacy type for
 /// proc-macro expansion.