summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_const_eval/src/interpret/call.rs53
-rw-r--r--src/tools/miri/tests/pass/function_calls/abi_compat.rs24
2 files changed, 59 insertions, 18 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/call.rs b/compiler/rustc_const_eval/src/interpret/call.rs
index 71ec9c2ab19..ef0902e4226 100644
--- a/compiler/rustc_const_eval/src/interpret/call.rs
+++ b/compiler/rustc_const_eval/src/interpret/call.rs
@@ -4,9 +4,9 @@ use std::assert_matches::assert_matches;
 use std::borrow::Cow;
 
 use either::{Left, Right};
-use rustc_abi::{self as abi, ExternAbi, FieldIdx, Integer};
+use rustc_abi::{self as abi, ExternAbi, FieldIdx, Integer, VariantIdx};
 use rustc_middle::ty::layout::{FnAbiOf, IntegerExt, LayoutOf, TyAndLayout};
-use rustc_middle::ty::{self, AdtDef, Instance, Ty};
+use rustc_middle::ty::{self, AdtDef, Instance, Ty, VariantDef};
 use rustc_middle::{bug, mir, span_bug};
 use rustc_span::sym;
 use rustc_target::callconv::{ArgAbi, FnAbi, PassMode};
@@ -92,29 +92,46 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
 
     /// Unwrap types that are guaranteed a null-pointer-optimization
     fn unfold_npo(&self, layout: TyAndLayout<'tcx>) -> InterpResult<'tcx, TyAndLayout<'tcx>> {
-        // Check if this is `Option` wrapping some type or if this is `Result` wrapping a 1-ZST and
-        // another type.
+        // Check if this is an option-like type wrapping some type.
         let ty::Adt(def, args) = layout.ty.kind() else {
             // Not an ADT, so definitely no NPO.
             return interp_ok(layout);
         };
-        let inner = if self.tcx.is_diagnostic_item(sym::Option, def.did()) {
-            // The wrapped type is the only arg.
-            self.layout_of(args[0].as_type().unwrap())?
-        } else if self.tcx.is_diagnostic_item(sym::Result, def.did()) {
-            // We want to extract which (if any) of the args is not a 1-ZST.
-            let lhs = self.layout_of(args[0].as_type().unwrap())?;
-            let rhs = self.layout_of(args[1].as_type().unwrap())?;
-            if lhs.is_1zst() {
-                rhs
-            } else if rhs.is_1zst() {
-                lhs
-            } else {
-                return interp_ok(layout); // no NPO
+        if def.variants().len() != 2 {
+            // Not a 2-variant enum, so no NPO.
+            return interp_ok(layout);
+        }
+        assert!(def.is_enum());
+
+        let all_fields_1zst = |variant: &VariantDef| -> InterpResult<'tcx, _> {
+            for field in &variant.fields {
+                let ty = field.ty(*self.tcx, args);
+                let layout = self.layout_of(ty)?;
+                if !layout.is_1zst() {
+                    return interp_ok(false);
+                }
             }
+            interp_ok(true)
+        };
+
+        // If one variant consists entirely of 1-ZST, then the other variant
+        // is the only "relevant" one for this check.
+        let var0 = VariantIdx::from_u32(0);
+        let var1 = VariantIdx::from_u32(1);
+        let relevant_variant = if all_fields_1zst(def.variant(var0))? {
+            def.variant(var1)
+        } else if all_fields_1zst(def.variant(var1))? {
+            def.variant(var0)
         } else {
-            return interp_ok(layout); // no NPO
+            // No varant is all-1-ZST, so no NPO.
+            return interp_ok(layout);
         };
+        // The "relevant" variant must have exactly one field, and its type is the "inner" type.
+        if relevant_variant.fields.len() != 1 {
+            return interp_ok(layout);
+        }
+        let inner = relevant_variant.fields[FieldIdx::from_u32(0)].ty(*self.tcx, args);
+        let inner = self.layout_of(inner)?;
 
         // Check if the inner type is one of the NPO-guaranteed ones.
         // For that we first unpeel transparent *structs* (but not unions).
diff --git a/src/tools/miri/tests/pass/function_calls/abi_compat.rs b/src/tools/miri/tests/pass/function_calls/abi_compat.rs
index b5feac8c677..cd48bd2accb 100644
--- a/src/tools/miri/tests/pass/function_calls/abi_compat.rs
+++ b/src/tools/miri/tests/pass/function_calls/abi_compat.rs
@@ -1,3 +1,5 @@
+#![feature(never_type)]
+
 use std::rc::Rc;
 use std::{mem, num, ptr};
 
@@ -12,6 +14,18 @@ fn id<T>(x: T) -> T {
     x
 }
 
+#[derive(Copy, Clone)]
+enum Either<T, U> {
+    Left(T),
+    Right(U),
+}
+#[derive(Copy, Clone)]
+enum Either2<T, U> {
+    Left(T),
+    #[allow(unused)]
+    Right(U, ()),
+}
+
 fn test_abi_compat<T: Clone, U: Clone>(t: T, u: U) {
     fn id<T>(x: T) -> T {
         x
@@ -81,6 +95,8 @@ fn main() {
     test_abi_compat(main as fn(), id::<i32> as fn(i32) -> i32);
     // - 1-ZST
     test_abi_compat((), [0u8; 0]);
+
+    // Guaranteed null-pointer-layout optimizations:
     // - Guaranteed Option<X> null-pointer-optimizations (RFC 3391).
     test_abi_compat(&0u32 as *const u32, Some(&0u32));
     test_abi_compat(main as fn(), Some(main as fn()));
@@ -89,6 +105,7 @@ fn main() {
     test_abi_compat(0u32, Some(Wrapper(num::NonZeroU32::new(1u32).unwrap())));
     // - Guaranteed Result<X, ZST1> does the same as Option<X> (RFC 3391)
     test_abi_compat(&0u32 as *const u32, Result::<_, ()>::Ok(&0u32));
+    test_abi_compat(&0u32 as *const u32, Result::<_, !>::Ok(&0u32));
     test_abi_compat(main as fn(), Result::<_, ()>::Ok(main as fn()));
     test_abi_compat(0u32, Result::<_, ()>::Ok(num::NonZeroU32::new(1).unwrap()));
     test_abi_compat(&0u32 as *const u32, Result::<_, ()>::Ok(Wrapper(&0u32)));
@@ -99,6 +116,13 @@ fn main() {
     test_abi_compat(0u32, Result::<(), _>::Err(num::NonZeroU32::new(1).unwrap()));
     test_abi_compat(&0u32 as *const u32, Result::<(), _>::Err(Wrapper(&0u32)));
     test_abi_compat(0u32, Result::<(), _>::Err(Wrapper(num::NonZeroU32::new(1).unwrap())));
+    // - Guaranteed null-pointer-optimizations for custom option-like types
+    test_abi_compat(&0u32 as *const u32, Either::<_, ()>::Left(&0u32));
+    test_abi_compat(&0u32 as *const u32, Either::<_, !>::Left(&0u32));
+    test_abi_compat(&0u32 as *const u32, Either::<(), _>::Right(&0u32));
+    test_abi_compat(&0u32 as *const u32, Either::<!, _>::Right(&0u32));
+    test_abi_compat(&0u32 as *const u32, Either2::<_, ()>::Left(&0u32));
+    test_abi_compat(&0u32 as *const u32, Either2::<_, [u8; 0]>::Left(&0u32));
 
     // These must work for *any* type, since we guarantee that `repr(transparent)` is ABI-compatible
     // with the wrapped field.