diff --git a/crates/compute_graph/Cargo.toml b/crates/compute_graph/Cargo.toml index f6fd4fb..0d295ae 100644 --- a/crates/compute_graph/Cargo.toml +++ b/crates/compute_graph/Cargo.toml @@ -8,6 +8,8 @@ edition = "2021" [dependencies] compute_graph_macros = { path = "../compute_graph_macros" } petgraph = "0.6.5" +syn = "2" +quote = "1" [dev-dependencies] tokio = { version = "1.41.0", features = ["rt", "macros"] } diff --git a/crates/compute_graph/src/lib.rs b/crates/compute_graph/src/lib.rs index 6edee7f..3f744c4 100644 --- a/crates/compute_graph/src/lib.rs +++ b/crates/compute_graph/src/lib.rs @@ -618,11 +618,6 @@ mod tests { fn non_cloneable_output() { #[derive(PartialEq, Debug)] struct NonCloneable; - impl NodeValue for NonCloneable { - fn node_value_eq(&self, _other: &Self) -> bool { - true - } - } struct Output; impl InputVisitable for Output { fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} diff --git a/crates/compute_graph/src/node.rs b/crates/compute_graph/src/node.rs index a881f58..be91a8c 100644 --- a/crates/compute_graph/src/node.rs +++ b/crates/compute_graph/src/node.rs @@ -1,5 +1,6 @@ use crate::synchronicity::{Asynchronous, Synchronicity}; use crate::{AsyncRule, Input, InputVisitor, NodeId, Rule, Synchronous}; +use quote::ToTokens; use std::any::Any; use std::cell::RefCell; use std::future::Future; @@ -85,13 +86,16 @@ pub(crate) trait Node: std::fmt::Debug { /// /// This trait is used to determine, when a node is invalidated, whether its value has truly changed /// and thus whether downstream nodes need to be invalidated too. +/// +/// A blanket implementation of this trait for all types implementing `PartialEq` is provided. pub trait NodeValue: 'static { /// Whether self is equal, for the purposes of graph invalidation, from other. /// /// This method should be conservative. That is, if the equality of the two values cannot be affirmatively /// determined, this method should return `false`. /// - /// The default implementation of this method always returns `false`. + /// The default implementation of this method always returns `false`, so any non-`PartialEq` type can + /// implement this trait simply: /// /// ```rust /// # use compute_graph::node::NodeValue; @@ -106,36 +110,12 @@ pub trait NodeValue: 'static { } } -impl NodeValue for () { - fn node_value_eq(&self, _other: &Self) -> bool { - true - } -} - -impl NodeValue for i32 { +impl NodeValue for T { fn node_value_eq(&self, other: &Self) -> bool { self == other } } -impl NodeValue for Option { - fn node_value_eq(&self, other: &Self) -> bool { - match (self, other) { - (Some(s), Some(o)) => s.node_value_eq(o), - _ => false, - } - } -} - -impl NodeValue for Result { - fn node_value_eq(&self, other: &Self) -> bool { - match (self, other) { - (Ok(s), Ok(o)) => s.node_value_eq(o), - _ => false, - } - } -} - pub(crate) struct ConstNode { value: Rc>>, synchronicity: std::marker::PhantomData, @@ -170,7 +150,7 @@ impl Node for ConstNode { impl std::fmt::Debug for ConstNode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "ConstNode<{}>", std::any::type_name::()) + write!(f, "ConstNode<{}>", pretty_type_name::()) } } @@ -215,7 +195,7 @@ impl Node for InvalidatableConstNode impl std::fmt::Debug for InvalidatableConstNode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "InvalidatableConstNode<{}>", std::any::type_name::()) + write!(f, "InvalidatableConstNode<{}>", pretty_type_name::()) } } @@ -289,7 +269,7 @@ impl std::fmt::Debug for RuleNode { write!( f, "RuleNode<{}>({})", - std::any::type_name::(), + pretty_type_name::(), RuleLabel(&self.rule) ) } @@ -343,7 +323,7 @@ impl F, F: Future> Node F, F: Future> std::fmt::Debug for AsyncConstNode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "AsyncConstNode<{}>", std::any::type_name::()) + write!(f, "AsyncConstNode<{}>", pretty_type_name::()) } } @@ -419,8 +399,53 @@ impl std::fmt::Debug for AsyncRuleNode { write!( f, "AsyncRuleNode<{}>({})", - std::any::type_name::(), + pretty_type_name::(), AsyncRuleLabel(&self.rule) ) } } + +fn pretty_type_name() -> String { + let s = std::any::type_name::(); + let ty = syn::parse_str::(s).unwrap(); + pretty_type_name_type(ty) +} + +fn pretty_type_name_type(ty: syn::Type) -> String { + match ty { + syn::Type::Path(path) => pretty_type_name_path(path), + _ => format!("{}", ty.into_token_stream()), + } +} + +fn pretty_type_name_path(path: syn::TypePath) -> String { + if path.qself.is_some() { + format!("{}", path.into_token_stream()) + } else { + let last_segment = path.path.segments.last().unwrap(); + match &last_segment.arguments { + syn::PathArguments::None => { + format!("{}", last_segment.ident.to_token_stream()) + } + syn::PathArguments::AngleBracketed(args) => { + let mut str = format!("{}", last_segment.ident.to_token_stream()); + str.push('<'); + for arg in &args.args { + match arg { + syn::GenericArgument::Type(ty) => { + str.push_str(&pretty_type_name_type(ty.clone())) + } + _ => str.push_str(&format!("{}", arg.into_token_stream())), + } + str.push_str(", ") + } + str.remove(str.len() - 1); + str.replace_range((str.len() - 1).., ">"); + str + } + syn::PathArguments::Parenthesized(_) => { + format!("{}", last_segment.into_token_stream()) + } + } + } +}