about summary refs log tree commit diff
path: root/compiler/rustc_ast/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_ast/src')
-rw-r--r--compiler/rustc_ast/src/ast.rs3
-rw-r--r--compiler/rustc_ast/src/expand/autodiff_attrs.rs17
-rw-r--r--compiler/rustc_ast/src/expand/typetree.rs1
-rw-r--r--compiler/rustc_ast/src/lib.rs1
-rw-r--r--compiler/rustc_ast/src/tokenstream.rs4
5 files changed, 18 insertions, 8 deletions
diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs
index 3e8fddd9954..082d5e88ac7 100644
--- a/compiler/rustc_ast/src/ast.rs
+++ b/compiler/rustc_ast/src/ast.rs
@@ -114,8 +114,7 @@ impl PartialEq<Symbol> for Path {
 impl PartialEq<&[Symbol]> for Path {
     #[inline]
     fn eq(&self, names: &&[Symbol]) -> bool {
-        self.segments.len() == names.len()
-            && self.segments.iter().zip(names.iter()).all(|(s1, s2)| s1 == s2)
+        self.segments.iter().eq(*names)
     }
 }
 
diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
index 33451f99748..90f15753e99 100644
--- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs
+++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
@@ -6,6 +6,7 @@
 use std::fmt::{self, Display, Formatter};
 use std::str::FromStr;
 
+use crate::expand::typetree::TypeTree;
 use crate::expand::{Decodable, Encodable, HashStable_Generic};
 use crate::{Ty, TyKind};
 
@@ -84,6 +85,8 @@ pub struct AutoDiffItem {
     /// The name of the function being generated
     pub target: String,
     pub attrs: AutoDiffAttrs,
+    pub inputs: Vec<TypeTree>,
+    pub output: TypeTree,
 }
 
 #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
@@ -275,14 +278,22 @@ impl AutoDiffAttrs {
         !matches!(self.mode, DiffMode::Error | DiffMode::Source)
     }
 
-    pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
-        AutoDiffItem { source, target, attrs: self }
+    pub fn into_item(
+        self,
+        source: String,
+        target: String,
+        inputs: Vec<TypeTree>,
+        output: TypeTree,
+    ) -> AutoDiffItem {
+        AutoDiffItem { source, target, inputs, output, attrs: self }
     }
 }
 
 impl fmt::Display for AutoDiffItem {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "Differentiating {} -> {}", self.source, self.target)?;
-        write!(f, " with attributes: {:?}", self.attrs)
+        write!(f, " with attributes: {:?}", self.attrs)?;
+        write!(f, " with inputs: {:?}", self.inputs)?;
+        write!(f, " with output: {:?}", self.output)
     }
 }
diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs
index 9a2dd2e85e0..e7b4f3aff41 100644
--- a/compiler/rustc_ast/src/expand/typetree.rs
+++ b/compiler/rustc_ast/src/expand/typetree.rs
@@ -31,6 +31,7 @@ pub enum Kind {
     Half,
     Float,
     Double,
+    F128,
     Unknown,
 }
 
diff --git a/compiler/rustc_ast/src/lib.rs b/compiler/rustc_ast/src/lib.rs
index f1951049b47..5fe218776e5 100644
--- a/compiler/rustc_ast/src/lib.rs
+++ b/compiler/rustc_ast/src/lib.rs
@@ -15,6 +15,7 @@
 #![feature(associated_type_defaults)]
 #![feature(box_patterns)]
 #![feature(if_let_guard)]
+#![feature(iter_order_by)]
 #![feature(macro_metavar_expr)]
 #![feature(rustdoc_internals)]
 #![recursion_limit = "256"]
diff --git a/compiler/rustc_ast/src/tokenstream.rs b/compiler/rustc_ast/src/tokenstream.rs
index a5d8fbfac61..4111182c3b7 100644
--- a/compiler/rustc_ast/src/tokenstream.rs
+++ b/compiler/rustc_ast/src/tokenstream.rs
@@ -48,9 +48,7 @@ impl TokenTree {
         match (self, other) {
             (TokenTree::Token(token, _), TokenTree::Token(token2, _)) => token.kind == token2.kind,
             (TokenTree::Delimited(.., delim, tts), TokenTree::Delimited(.., delim2, tts2)) => {
-                delim == delim2
-                    && tts.len() == tts2.len()
-                    && tts.iter().zip(tts2.iter()).all(|(a, b)| a.eq_unspanned(b))
+                delim == delim2 && tts.iter().eq_by(tts2.iter(), |a, b| a.eq_unspanned(b))
             }
             _ => false,
         }