Pretty type names in graphviz dump, bring back PartialEq/NodeValue impl

This commit is contained in:
Shadowfacts 2024-11-05 11:19:23 -05:00
parent b8ad929d0b
commit 8f0fe08ecc
3 changed files with 58 additions and 36 deletions

View File

@ -8,6 +8,8 @@ edition = "2021"
[dependencies] [dependencies]
compute_graph_macros = { path = "../compute_graph_macros" } compute_graph_macros = { path = "../compute_graph_macros" }
petgraph = "0.6.5" petgraph = "0.6.5"
syn = "2"
quote = "1"
[dev-dependencies] [dev-dependencies]
tokio = { version = "1.41.0", features = ["rt", "macros"] } tokio = { version = "1.41.0", features = ["rt", "macros"] }

View File

@ -618,11 +618,6 @@ mod tests {
fn non_cloneable_output() { fn non_cloneable_output() {
#[derive(PartialEq, Debug)] #[derive(PartialEq, Debug)]
struct NonCloneable; struct NonCloneable;
impl NodeValue for NonCloneable {
fn node_value_eq(&self, _other: &Self) -> bool {
true
}
}
struct Output; struct Output;
impl InputVisitable for Output { impl InputVisitable for Output {
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}

View File

@ -1,5 +1,6 @@
use crate::synchronicity::{Asynchronous, Synchronicity}; use crate::synchronicity::{Asynchronous, Synchronicity};
use crate::{AsyncRule, Input, InputVisitor, NodeId, Rule, Synchronous}; use crate::{AsyncRule, Input, InputVisitor, NodeId, Rule, Synchronous};
use quote::ToTokens;
use std::any::Any; use std::any::Any;
use std::cell::RefCell; use std::cell::RefCell;
use std::future::Future; use std::future::Future;
@ -85,13 +86,16 @@ pub(crate) trait Node<Value: NodeValue, Synch: Synchronicity>: std::fmt::Debug {
/// ///
/// This trait is used to determine, when a node is invalidated, whether its value has truly changed /// This trait is used to determine, when a node is invalidated, whether its value has truly changed
/// and thus whether downstream nodes need to be invalidated too. /// and thus whether downstream nodes need to be invalidated too.
///
/// A blanket implementation of this trait for all types implementing `PartialEq` is provided.
pub trait NodeValue: 'static { pub trait NodeValue: 'static {
/// Whether self is equal, for the purposes of graph invalidation, from other. /// Whether self is equal, for the purposes of graph invalidation, from other.
/// ///
/// This method should be conservative. That is, if the equality of the two values cannot be affirmatively /// This method should be conservative. That is, if the equality of the two values cannot be affirmatively
/// determined, this method should return `false`. /// determined, this method should return `false`.
/// ///
/// The default implementation of this method always returns `false`. /// The default implementation of this method always returns `false`, so any non-`PartialEq` type can
/// implement this trait simply:
/// ///
/// ```rust /// ```rust
/// # use compute_graph::node::NodeValue; /// # use compute_graph::node::NodeValue;
@ -106,36 +110,12 @@ pub trait NodeValue: 'static {
} }
} }
impl NodeValue for () { impl<T: PartialEq + 'static> NodeValue for T {
fn node_value_eq(&self, _other: &Self) -> bool {
true
}
}
impl NodeValue for i32 {
fn node_value_eq(&self, other: &Self) -> bool { fn node_value_eq(&self, other: &Self) -> bool {
self == other self == other
} }
} }
impl<T: NodeValue> NodeValue for Option<T> {
fn node_value_eq(&self, other: &Self) -> bool {
match (self, other) {
(Some(s), Some(o)) => s.node_value_eq(o),
_ => false,
}
}
}
impl<T: NodeValue, E: 'static> NodeValue for Result<T, E> {
fn node_value_eq(&self, other: &Self) -> bool {
match (self, other) {
(Ok(s), Ok(o)) => s.node_value_eq(o),
_ => false,
}
}
}
pub(crate) struct ConstNode<V, S> { pub(crate) struct ConstNode<V, S> {
value: Rc<RefCell<Option<V>>>, value: Rc<RefCell<Option<V>>>,
synchronicity: std::marker::PhantomData<S>, synchronicity: std::marker::PhantomData<S>,
@ -170,7 +150,7 @@ impl<V: NodeValue, S: Synchronicity> Node<V, S> for ConstNode<V, S> {
impl<V, S> std::fmt::Debug for ConstNode<V, S> { impl<V, S> std::fmt::Debug for ConstNode<V, S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "ConstNode<{}>", std::any::type_name::<V>()) write!(f, "ConstNode<{}>", pretty_type_name::<V>())
} }
} }
@ -215,7 +195,7 @@ impl<V: NodeValue, S: Synchronicity> Node<V, S> for InvalidatableConstNode<V, S>
impl<V, S> std::fmt::Debug for InvalidatableConstNode<V, S> { impl<V, S> std::fmt::Debug for InvalidatableConstNode<V, S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "InvalidatableConstNode<{}>", std::any::type_name::<V>()) write!(f, "InvalidatableConstNode<{}>", pretty_type_name::<V>())
} }
} }
@ -289,7 +269,7 @@ impl<R: Rule, V, S> std::fmt::Debug for RuleNode<R, V, S> {
write!( write!(
f, f,
"RuleNode<{}>({})", "RuleNode<{}>({})",
std::any::type_name::<R>(), pretty_type_name::<R>(),
RuleLabel(&self.rule) RuleLabel(&self.rule)
) )
} }
@ -343,7 +323,7 @@ impl<V: NodeValue, P: FnOnce() -> F, F: Future<Output = V>> Node<V, Asynchronous
impl<V, P: FnOnce() -> F, F: Future<Output = V>> std::fmt::Debug for AsyncConstNode<V, P, F> { impl<V, P: FnOnce() -> F, F: Future<Output = V>> std::fmt::Debug for AsyncConstNode<V, P, F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "AsyncConstNode<{}>", std::any::type_name::<V>()) write!(f, "AsyncConstNode<{}>", pretty_type_name::<V>())
} }
} }
@ -419,8 +399,53 @@ impl<R: AsyncRule, V> std::fmt::Debug for AsyncRuleNode<R, V> {
write!( write!(
f, f,
"AsyncRuleNode<{}>({})", "AsyncRuleNode<{}>({})",
std::any::type_name::<R>(), pretty_type_name::<R>(),
AsyncRuleLabel(&self.rule) AsyncRuleLabel(&self.rule)
) )
} }
} }
fn pretty_type_name<T>() -> String {
let s = std::any::type_name::<T>();
let ty = syn::parse_str::<syn::Type>(s).unwrap();
pretty_type_name_type(ty)
}
fn pretty_type_name_type(ty: syn::Type) -> String {
match ty {
syn::Type::Path(path) => pretty_type_name_path(path),
_ => format!("{}", ty.into_token_stream()),
}
}
fn pretty_type_name_path(path: syn::TypePath) -> String {
if path.qself.is_some() {
format!("{}", path.into_token_stream())
} else {
let last_segment = path.path.segments.last().unwrap();
match &last_segment.arguments {
syn::PathArguments::None => {
format!("{}", last_segment.ident.to_token_stream())
}
syn::PathArguments::AngleBracketed(args) => {
let mut str = format!("{}", last_segment.ident.to_token_stream());
str.push('<');
for arg in &args.args {
match arg {
syn::GenericArgument::Type(ty) => {
str.push_str(&pretty_type_name_type(ty.clone()))
}
_ => str.push_str(&format!("{}", arg.into_token_stream())),
}
str.push_str(", ")
}
str.remove(str.len() - 1);
str.replace_range((str.len() - 1).., ">");
str
}
syn::PathArguments::Parenthesized(_) => {
format!("{}", last_segment.into_token_stream())
}
}
}
}