diff --git a/crates/compute_graph/src/builder.rs b/crates/compute_graph/src/builder.rs index 1ceb63b..8cd77bb 100644 --- a/crates/compute_graph/src/builder.rs +++ b/crates/compute_graph/src/builder.rs @@ -333,6 +333,14 @@ pub enum BuildGraphError { Cycle(Vec), } +impl std::fmt::Display for BuildGraphError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + +impl std::error::Error for BuildGraphError {} + #[cfg(test)] mod tests { use crate::tests::{DeferredInput, Double}; diff --git a/crates/compute_graph/src/lib.rs b/crates/compute_graph/src/lib.rs index f81d0f6..6edee7f 100644 --- a/crates/compute_graph/src/lib.rs +++ b/crates/compute_graph/src/lib.rs @@ -618,6 +618,11 @@ mod tests { fn non_cloneable_output() { #[derive(PartialEq, Debug)] struct NonCloneable; + impl NodeValue for NonCloneable { + fn node_value_eq(&self, _other: &Self) -> bool { + true + } + } struct Output; impl InputVisitable for Output { fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} @@ -689,7 +694,23 @@ mod tests { let mut builder = GraphBuilder::new(); let a = builder.add_value(1); let b = builder.add_value(2); - builder.set_output(Add(a, b)); + struct AddWithLabel(Input, Input); + impl InputVisitable for AddWithLabel { + fn visit_inputs(&self, visitor: &mut impl InputVisitor) { + visitor.visit(&self.0); + visitor.visit(&self.1); + } + } + impl Rule for AddWithLabel { + type Output = i32; + fn evaluate(&mut self) -> Self::Output { + *self.0.value() + *self.1.value() + } + fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "test") + } + } + builder.set_output(AddWithLabel(a, b)); let graph = builder.build().unwrap(); println!("{}", graph.as_dot_string()); assert_eq!( @@ -697,7 +718,7 @@ mod tests { r#"digraph { 0 [label ="ConstNode (id=0)"] 1 [label ="ConstNode (id=1)"] - 2 [label ="RuleNode (id=2)"] + 2 [label ="RuleNode(test) (id=2)"] 0 -> 2 [] 1 -> 2 [] } diff --git a/crates/compute_graph/src/node.rs b/crates/compute_graph/src/node.rs index 32f54b0..a881f58 100644 --- a/crates/compute_graph/src/node.rs +++ b/crates/compute_graph/src/node.rs @@ -85,16 +85,13 @@ pub(crate) trait Node: std::fmt::Debug { /// /// This trait is used to determine, when a node is invalidated, whether its value has truly changed /// and thus whether downstream nodes need to be invalidated too. -/// -/// A blanket implementation of this trait for all types implementing `PartialEq` is provided. pub trait NodeValue: 'static { /// Whether self is equal, for the purposes of graph invalidation, from other. /// /// This method should be conservative. That is, if the equality of the two values cannot be affirmatively /// determined, this method should return `false`. /// - /// The default implementation of this method always returns `false`, so any non-`PartialEq` type can - /// implement this trait simply: + /// The default implementation of this method always returns `false`. /// /// ```rust /// # use compute_graph::node::NodeValue; @@ -109,12 +106,36 @@ pub trait NodeValue: 'static { } } -impl NodeValue for T { +impl NodeValue for () { + fn node_value_eq(&self, _other: &Self) -> bool { + true + } +} + +impl NodeValue for i32 { fn node_value_eq(&self, other: &Self) -> bool { self == other } } +impl NodeValue for Option { + fn node_value_eq(&self, other: &Self) -> bool { + match (self, other) { + (Some(s), Some(o)) => s.node_value_eq(o), + _ => false, + } + } +} + +impl NodeValue for Result { + fn node_value_eq(&self, other: &Self) -> bool { + match (self, other) { + (Ok(s), Ok(o)) => s.node_value_eq(o), + _ => false, + } + } +} + pub(crate) struct ConstNode { value: Rc>>, synchronicity: std::marker::PhantomData, @@ -256,9 +277,21 @@ impl Node for RuleNode } } -impl std::fmt::Debug for RuleNode { +struct RuleLabel<'a, R: Rule>(&'a R); +impl<'a, R: Rule> std::fmt::Display for RuleLabel<'a, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "RuleNode<{}>", std::any::type_name::()) + self.0.node_label(f) + } +} + +impl std::fmt::Debug for RuleNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "RuleNode<{}>({})", + std::any::type_name::(), + RuleLabel(&self.rule) + ) } } @@ -374,8 +407,20 @@ impl Node for AsyncRuleNode } } -impl std::fmt::Debug for AsyncRuleNode { +struct AsyncRuleLabel<'a, R: AsyncRule>(&'a R); +impl<'a, R: AsyncRule> std::fmt::Display for AsyncRuleLabel<'a, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "AsyncRuleNode<{}>", std::any::type_name::()) + self.0.node_label(f) + } +} + +impl std::fmt::Debug for AsyncRuleNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "AsyncRuleNode<{}>({})", + std::any::type_name::(), + AsyncRuleLabel(&self.rule) + ) } } diff --git a/crates/compute_graph/src/rule.rs b/crates/compute_graph/src/rule.rs index 6674b1d..699b410 100644 --- a/crates/compute_graph/src/rule.rs +++ b/crates/compute_graph/src/rule.rs @@ -38,6 +38,11 @@ pub trait Rule: InputVisitable + 'static { /// 2. A rule will never be evaluated before _all_ of its dependencies up-to-date. That is, it will never /// be evaluated with mix of valid and invalid dependencies. fn evaluate(&mut self) -> Self::Output; + + #[allow(unused_variables)] + fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + Ok(()) + } } /// A rule produces a value for a graph node asynchronously. @@ -64,6 +69,11 @@ pub trait AsyncRule: InputVisitable + 'static { /// /// See [`Rule::evaluate`] for additional details; the same considerations apply. fn evaluate(&mut self) -> impl Future + '_; + + #[allow(unused_variables)] + fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + Ok(()) + } } /// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited.