about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2025-03-01 19:28:04 +0000
committerMichael Goulet <michael@errs.io>2025-03-03 01:34:09 +0000
commit83fa2faf239b4bcde67eda7adf0a1a10dfca620d (patch)
tree26b960cb31d5e388b9f610c03fe33059434f0f59
parentf4a216d28ee635afce685b4206e713579f66e130 (diff)
downloadrust-83fa2faf239b4bcde67eda7adf0a1a10dfca620d.tar.gz
rust-83fa2faf239b4bcde67eda7adf0a1a10dfca620d.zip
Fix pretty printing of unsafe binders
-rw-r--r--compiler/rustc_middle/src/ty/print/pretty.rs117
-rw-r--r--compiler/rustc_symbol_mangling/src/v0.rs7
-rw-r--r--compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs6
-rw-r--r--tests/ui/unsafe-binders/type-mismatch.rs9
-rw-r--r--tests/ui/unsafe-binders/type-mismatch.stderr34
5 files changed, 131 insertions, 42 deletions
diff --git a/compiler/rustc_middle/src/ty/print/pretty.rs b/compiler/rustc_middle/src/ty/print/pretty.rs
index c77b37a302b..f1e82e8816c 100644
--- a/compiler/rustc_middle/src/ty/print/pretty.rs
+++ b/compiler/rustc_middle/src/ty/print/pretty.rs
@@ -133,6 +133,20 @@ pub macro with_no_queries($e:expr) {{
     ))
 }}
 
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub enum WrapBinderMode {
+    ForAll,
+    Unsafe,
+}
+impl WrapBinderMode {
+    pub fn start_str(self) -> &'static str {
+        match self {
+            WrapBinderMode::ForAll => "for<",
+            WrapBinderMode::Unsafe => "unsafe<",
+        }
+    }
+}
+
 /// The "region highlights" are used to control region printing during
 /// specific error messages. When a "region highlight" is enabled, it
 /// gives an alternate way to print specific regions. For now, we
@@ -219,7 +233,11 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
         self.print_def_path(def_id, args)
     }
 
-    fn in_binder<T>(&mut self, value: &ty::Binder<'tcx, T>) -> Result<(), PrintError>
+    fn in_binder<T>(
+        &mut self,
+        value: &ty::Binder<'tcx, T>,
+        _mode: WrapBinderMode,
+    ) -> Result<(), PrintError>
     where
         T: Print<'tcx, Self> + TypeFoldable<TyCtxt<'tcx>>,
     {
@@ -229,6 +247,7 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
     fn wrap_binder<T, F: FnOnce(&T, &mut Self) -> Result<(), fmt::Error>>(
         &mut self,
         value: &ty::Binder<'tcx, T>,
+        _mode: WrapBinderMode,
         f: F,
     ) -> Result<(), PrintError>
     where
@@ -703,8 +722,9 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
             }
             ty::FnPtr(ref sig_tys, hdr) => p!(print(sig_tys.with(hdr))),
             ty::UnsafeBinder(ref bound_ty) => {
-                // FIXME(unsafe_binders): Make this print `unsafe<>` rather than `for<>`.
-                self.wrap_binder(bound_ty, |ty, cx| cx.pretty_print_type(*ty))?;
+                self.wrap_binder(bound_ty, WrapBinderMode::Unsafe, |ty, cx| {
+                    cx.pretty_print_type(*ty)
+                })?;
             }
             ty::Infer(infer_ty) => {
                 if self.should_print_verbose() {
@@ -1067,29 +1087,33 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
             };
 
             if let Some(return_ty) = entry.return_ty {
-                self.wrap_binder(&bound_args_and_self_ty, |(args, _), cx| {
-                    define_scoped_cx!(cx);
-                    p!(write("{}", tcx.item_name(trait_def_id)));
-                    p!("(");
-
-                    for (idx, ty) in args.iter().enumerate() {
-                        if idx > 0 {
-                            p!(", ");
+                self.wrap_binder(
+                    &bound_args_and_self_ty,
+                    WrapBinderMode::ForAll,
+                    |(args, _), cx| {
+                        define_scoped_cx!(cx);
+                        p!(write("{}", tcx.item_name(trait_def_id)));
+                        p!("(");
+
+                        for (idx, ty) in args.iter().enumerate() {
+                            if idx > 0 {
+                                p!(", ");
+                            }
+                            p!(print(ty));
                         }
-                        p!(print(ty));
-                    }
 
-                    p!(")");
-                    if let Some(ty) = return_ty.skip_binder().as_type() {
-                        if !ty.is_unit() {
-                            p!(" -> ", print(return_ty));
+                        p!(")");
+                        if let Some(ty) = return_ty.skip_binder().as_type() {
+                            if !ty.is_unit() {
+                                p!(" -> ", print(return_ty));
+                            }
                         }
-                    }
-                    p!(write("{}", if paren_needed { ")" } else { "" }));
+                        p!(write("{}", if paren_needed { ")" } else { "" }));
 
-                    first = false;
-                    Ok(())
-                })?;
+                        first = false;
+                        Ok(())
+                    },
+                )?;
             } else {
                 // Otherwise, render this like a regular trait.
                 traits.insert(
@@ -1110,7 +1134,7 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
         for (trait_pred, assoc_items) in traits {
             write!(self, "{}", if first { "" } else { " + " })?;
 
-            self.wrap_binder(&trait_pred, |trait_pred, cx| {
+            self.wrap_binder(&trait_pred, WrapBinderMode::ForAll, |trait_pred, cx| {
                 define_scoped_cx!(cx);
 
                 if trait_pred.polarity == ty::PredicatePolarity::Negative {
@@ -1302,7 +1326,7 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
         let mut first = true;
 
         if let Some(bound_principal) = predicates.principal() {
-            self.wrap_binder(&bound_principal, |principal, cx| {
+            self.wrap_binder(&bound_principal, WrapBinderMode::ForAll, |principal, cx| {
                 define_scoped_cx!(cx);
                 p!(print_def_path(principal.def_id, &[]));
 
@@ -1927,7 +1951,7 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
         let kind = closure.kind_ty().to_opt_closure_kind().unwrap_or(ty::ClosureKind::Fn);
 
         write!(self, "impl ")?;
-        self.wrap_binder(&sig, |sig, cx| {
+        self.wrap_binder(&sig, WrapBinderMode::ForAll, |sig, cx| {
             define_scoped_cx!(cx);
 
             p!(write("{kind}("));
@@ -2367,22 +2391,27 @@ impl<'tcx> PrettyPrinter<'tcx> for FmtPrinter<'_, 'tcx> {
         Ok(())
     }
 
-    fn in_binder<T>(&mut self, value: &ty::Binder<'tcx, T>) -> Result<(), PrintError>
+    fn in_binder<T>(
+        &mut self,
+        value: &ty::Binder<'tcx, T>,
+        mode: WrapBinderMode,
+    ) -> Result<(), PrintError>
     where
         T: Print<'tcx, Self> + TypeFoldable<TyCtxt<'tcx>>,
     {
-        self.pretty_in_binder(value)
+        self.pretty_in_binder(value, mode)
     }
 
     fn wrap_binder<T, C: FnOnce(&T, &mut Self) -> Result<(), PrintError>>(
         &mut self,
         value: &ty::Binder<'tcx, T>,
+        mode: WrapBinderMode,
         f: C,
     ) -> Result<(), PrintError>
     where
         T: TypeFoldable<TyCtxt<'tcx>>,
     {
-        self.pretty_wrap_binder(value, f)
+        self.pretty_wrap_binder(value, mode, f)
     }
 
     fn typed_value(
@@ -2632,6 +2661,7 @@ impl<'tcx> FmtPrinter<'_, 'tcx> {
     pub fn name_all_regions<T>(
         &mut self,
         value: &ty::Binder<'tcx, T>,
+        mode: WrapBinderMode,
     ) -> Result<(T, UnordMap<ty::BoundRegion, ty::Region<'tcx>>), fmt::Error>
     where
         T: TypeFoldable<TyCtxt<'tcx>>,
@@ -2705,9 +2735,13 @@ impl<'tcx> FmtPrinter<'_, 'tcx> {
         // anyways.
         let (new_value, map) = if self.should_print_verbose() {
             for var in value.bound_vars().iter() {
-                start_or_continue(self, "for<", ", ");
+                start_or_continue(self, mode.start_str(), ", ");
                 write!(self, "{var:?}")?;
             }
+            // Unconditionally render `unsafe<>`.
+            if value.bound_vars().is_empty() && mode == WrapBinderMode::Unsafe {
+                start_or_continue(self, mode.start_str(), "");
+            }
             start_or_continue(self, "", "> ");
             (value.clone().skip_binder(), UnordMap::default())
         } else {
@@ -2772,8 +2806,9 @@ impl<'tcx> FmtPrinter<'_, 'tcx> {
                     }
                 };
 
-                if !trim_path {
-                    start_or_continue(self, "for<", ", ");
+                // Unconditionally render `unsafe<>`.
+                if !trim_path || mode == WrapBinderMode::Unsafe {
+                    start_or_continue(self, mode.start_str(), ", ");
                     do_continue(self, name);
                 }
                 ty::Region::new_bound(tcx, ty::INNERMOST, ty::BoundRegion { var: br.var, kind })
@@ -2786,9 +2821,12 @@ impl<'tcx> FmtPrinter<'_, 'tcx> {
             };
             let new_value = value.clone().skip_binder().fold_with(&mut folder);
             let region_map = folder.region_map;
-            if !trim_path {
-                start_or_continue(self, "", "> ");
+
+            if mode == WrapBinderMode::Unsafe && region_map.is_empty() {
+                start_or_continue(self, mode.start_str(), "");
             }
+            start_or_continue(self, "", "> ");
+
             (new_value, region_map)
         };
 
@@ -2797,12 +2835,16 @@ impl<'tcx> FmtPrinter<'_, 'tcx> {
         Ok((new_value, map))
     }
 
-    pub fn pretty_in_binder<T>(&mut self, value: &ty::Binder<'tcx, T>) -> Result<(), fmt::Error>
+    pub fn pretty_in_binder<T>(
+        &mut self,
+        value: &ty::Binder<'tcx, T>,
+        mode: WrapBinderMode,
+    ) -> Result<(), fmt::Error>
     where
         T: Print<'tcx, Self> + TypeFoldable<TyCtxt<'tcx>>,
     {
         let old_region_index = self.region_index;
-        let (new_value, _) = self.name_all_regions(value)?;
+        let (new_value, _) = self.name_all_regions(value, mode)?;
         new_value.print(self)?;
         self.region_index = old_region_index;
         self.binder_depth -= 1;
@@ -2812,13 +2854,14 @@ impl<'tcx> FmtPrinter<'_, 'tcx> {
     pub fn pretty_wrap_binder<T, C: FnOnce(&T, &mut Self) -> Result<(), fmt::Error>>(
         &mut self,
         value: &ty::Binder<'tcx, T>,
+        mode: WrapBinderMode,
         f: C,
     ) -> Result<(), fmt::Error>
     where
         T: TypeFoldable<TyCtxt<'tcx>>,
     {
         let old_region_index = self.region_index;
-        let (new_value, _) = self.name_all_regions(value)?;
+        let (new_value, _) = self.name_all_regions(value, mode)?;
         f(&new_value, self)?;
         self.region_index = old_region_index;
         self.binder_depth -= 1;
@@ -2877,7 +2920,7 @@ where
     T: Print<'tcx, P> + TypeFoldable<TyCtxt<'tcx>>,
 {
     fn print(&self, cx: &mut P) -> Result<(), PrintError> {
-        cx.in_binder(self)
+        cx.in_binder(self, WrapBinderMode::ForAll)
     }
 }
 
diff --git a/compiler/rustc_symbol_mangling/src/v0.rs b/compiler/rustc_symbol_mangling/src/v0.rs
index 4fafd1ac350..5f56e3d2145 100644
--- a/compiler/rustc_symbol_mangling/src/v0.rs
+++ b/compiler/rustc_symbol_mangling/src/v0.rs
@@ -12,7 +12,7 @@ use rustc_hir::def_id::{CrateNum, DefId};
 use rustc_hir::definitions::{DefPathData, DisambiguatedDefPathData};
 use rustc_middle::bug;
 use rustc_middle::ty::layout::IntegerExt;
-use rustc_middle::ty::print::{Print, PrintError, Printer};
+use rustc_middle::ty::print::{Print, PrintError, Printer, WrapBinderMode};
 use rustc_middle::ty::{
     self, FloatTy, GenericArg, GenericArgKind, Instance, IntTy, ReifyReason, Ty, TyCtxt,
     TypeVisitable, TypeVisitableExt, UintTy,
@@ -172,6 +172,7 @@ impl<'tcx> SymbolMangler<'tcx> {
     fn in_binder<T>(
         &mut self,
         value: &ty::Binder<'tcx, T>,
+        _mode: WrapBinderMode,
         print_value: impl FnOnce(&mut Self, &T) -> Result<(), PrintError>,
     ) -> Result<(), PrintError>
     where
@@ -471,7 +472,7 @@ impl<'tcx> Printer<'tcx> for SymbolMangler<'tcx> {
             ty::FnPtr(sig_tys, hdr) => {
                 let sig = sig_tys.with(hdr);
                 self.push("F");
-                self.in_binder(&sig, |cx, sig| {
+                self.in_binder(&sig, WrapBinderMode::ForAll, |cx, sig| {
                     if sig.safety.is_unsafe() {
                         cx.push("U");
                     }
@@ -554,7 +555,7 @@ impl<'tcx> Printer<'tcx> for SymbolMangler<'tcx> {
         // [<Trait> [{<Projection>}]] [{<Auto>}]
         // Since any predicates after the first one shouldn't change the binders,
         // just put them all in the binders of the first.
-        self.in_binder(&predicates[0], |cx, _| {
+        self.in_binder(&predicates[0], WrapBinderMode::ForAll, |cx, _| {
             for predicate in predicates.iter() {
                 // It would be nice to be able to validate bound vars here, but
                 // projections can actually include bound vars from super traits
diff --git a/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs b/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs
index 847bd06bb01..42d37418fb8 100644
--- a/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs
+++ b/compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs
@@ -65,7 +65,9 @@ use rustc_middle::bug;
 use rustc_middle::dep_graph::DepContext;
 use rustc_middle::traits::PatternOriginExpr;
 use rustc_middle::ty::error::{ExpectedFound, TypeError, TypeErrorToStringExt};
-use rustc_middle::ty::print::{PrintError, PrintTraitRefExt as _, with_forced_trimmed_paths};
+use rustc_middle::ty::print::{
+    PrintError, PrintTraitRefExt as _, WrapBinderMode, with_forced_trimmed_paths,
+};
 use rustc_middle::ty::{
     self, List, ParamEnv, Region, Ty, TyCtxt, TypeFoldable, TypeSuperVisitable, TypeVisitable,
     TypeVisitableExt,
@@ -835,7 +837,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
         let get_lifetimes = |sig| {
             use rustc_hir::def::Namespace;
             let (sig, reg) = ty::print::FmtPrinter::new(self.tcx, Namespace::TypeNS)
-                .name_all_regions(sig)
+                .name_all_regions(sig, WrapBinderMode::ForAll)
                 .unwrap();
             let lts: Vec<String> =
                 reg.into_items().map(|(_, kind)| kind.to_string()).into_sorted_stable_ord();
diff --git a/tests/ui/unsafe-binders/type-mismatch.rs b/tests/ui/unsafe-binders/type-mismatch.rs
new file mode 100644
index 00000000000..9ac4e817c28
--- /dev/null
+++ b/tests/ui/unsafe-binders/type-mismatch.rs
@@ -0,0 +1,9 @@
+#![feature(unsafe_binders)]
+//~^ WARN the feature `unsafe_binders` is incomplete
+
+fn main() {
+    let x: unsafe<> i32 = 0;
+    //~^ ERROR mismatched types
+    let x: unsafe<'a> &'a i32 = &0;
+    //~^ ERROR mismatched types
+}
diff --git a/tests/ui/unsafe-binders/type-mismatch.stderr b/tests/ui/unsafe-binders/type-mismatch.stderr
new file mode 100644
index 00000000000..e694b5d464d
--- /dev/null
+++ b/tests/ui/unsafe-binders/type-mismatch.stderr
@@ -0,0 +1,34 @@
+warning: the feature `unsafe_binders` is incomplete and may not be safe to use and/or cause compiler crashes
+  --> $DIR/type-mismatch.rs:1:12
+   |
+LL | #![feature(unsafe_binders)]
+   |            ^^^^^^^^^^^^^^
+   |
+   = note: see issue #130516 <https://github.com/rust-lang/rust/issues/130516> for more information
+   = note: `#[warn(incomplete_features)]` on by default
+
+error[E0308]: mismatched types
+  --> $DIR/type-mismatch.rs:5:27
+   |
+LL |     let x: unsafe<> i32 = 0;
+   |            ------------   ^ expected `unsafe<> i32`, found integer
+   |            |
+   |            expected due to this
+   |
+   = note: expected unsafe binder `unsafe<> i32`
+                       found type `{integer}`
+
+error[E0308]: mismatched types
+  --> $DIR/type-mismatch.rs:7:33
+   |
+LL |     let x: unsafe<'a> &'a i32 = &0;
+   |            ------------------   ^^ expected `unsafe<'a> &i32`, found `&{integer}`
+   |            |
+   |            expected due to this
+   |
+   = note: expected unsafe binder `unsafe<'a> &'a i32`
+                  found reference `&{integer}`
+
+error: aborting due to 2 previous errors; 1 warning emitted
+
+For more information about this error, try `rustc --explain E0308`.