Assorted graph tweaks

This commit is contained in:
Shadowfacts 2024-11-03 14:42:45 -05:00
parent 712b528ca8
commit e69014d98d
4 changed files with 95 additions and 11 deletions

View File

@ -333,6 +333,14 @@ pub enum BuildGraphError {
Cycle(Vec<NodeId>), Cycle(Vec<NodeId>),
} }
impl std::fmt::Display for BuildGraphError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
impl std::error::Error for BuildGraphError {}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::tests::{DeferredInput, Double}; use crate::tests::{DeferredInput, Double};

View File

@ -618,6 +618,11 @@ mod tests {
fn non_cloneable_output() { fn non_cloneable_output() {
#[derive(PartialEq, Debug)] #[derive(PartialEq, Debug)]
struct NonCloneable; struct NonCloneable;
impl NodeValue for NonCloneable {
fn node_value_eq(&self, _other: &Self) -> bool {
true
}
}
struct Output; struct Output;
impl InputVisitable for Output { impl InputVisitable for Output {
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
@ -689,7 +694,23 @@ mod tests {
let mut builder = GraphBuilder::new(); let mut builder = GraphBuilder::new();
let a = builder.add_value(1); let a = builder.add_value(1);
let b = builder.add_value(2); let b = builder.add_value(2);
builder.set_output(Add(a, b)); struct AddWithLabel(Input<i32>, Input<i32>);
impl InputVisitable for AddWithLabel {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
visitor.visit(&self.1);
}
}
impl Rule for AddWithLabel {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
*self.0.value() + *self.1.value()
}
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "test")
}
}
builder.set_output(AddWithLabel(a, b));
let graph = builder.build().unwrap(); let graph = builder.build().unwrap();
println!("{}", graph.as_dot_string()); println!("{}", graph.as_dot_string());
assert_eq!( assert_eq!(
@ -697,7 +718,7 @@ mod tests {
r#"digraph { r#"digraph {
0 [label ="ConstNode<i32> (id=0)"] 0 [label ="ConstNode<i32> (id=0)"]
1 [label ="ConstNode<i32> (id=1)"] 1 [label ="ConstNode<i32> (id=1)"]
2 [label ="RuleNode<compute_graph::tests::Add> (id=2)"] 2 [label ="RuleNode<compute_graph::tests::graphviz::AddWithLabel>(test) (id=2)"]
0 -> 2 [] 0 -> 2 []
1 -> 2 [] 1 -> 2 []
} }

View File

@ -85,16 +85,13 @@ pub(crate) trait Node<Value: NodeValue, Synch: Synchronicity>: std::fmt::Debug {
/// ///
/// This trait is used to determine, when a node is invalidated, whether its value has truly changed /// This trait is used to determine, when a node is invalidated, whether its value has truly changed
/// and thus whether downstream nodes need to be invalidated too. /// and thus whether downstream nodes need to be invalidated too.
///
/// A blanket implementation of this trait for all types implementing `PartialEq` is provided.
pub trait NodeValue: 'static { pub trait NodeValue: 'static {
/// Whether self is equal, for the purposes of graph invalidation, from other. /// Whether self is equal, for the purposes of graph invalidation, from other.
/// ///
/// This method should be conservative. That is, if the equality of the two values cannot be affirmatively /// This method should be conservative. That is, if the equality of the two values cannot be affirmatively
/// determined, this method should return `false`. /// determined, this method should return `false`.
/// ///
/// The default implementation of this method always returns `false`, so any non-`PartialEq` type can /// The default implementation of this method always returns `false`.
/// implement this trait simply:
/// ///
/// ```rust /// ```rust
/// # use compute_graph::node::NodeValue; /// # use compute_graph::node::NodeValue;
@ -109,12 +106,36 @@ pub trait NodeValue: 'static {
} }
} }
impl<T: PartialEq + 'static> NodeValue for T { impl NodeValue for () {
fn node_value_eq(&self, _other: &Self) -> bool {
true
}
}
impl NodeValue for i32 {
fn node_value_eq(&self, other: &Self) -> bool { fn node_value_eq(&self, other: &Self) -> bool {
self == other self == other
} }
} }
impl<T: NodeValue> NodeValue for Option<T> {
fn node_value_eq(&self, other: &Self) -> bool {
match (self, other) {
(Some(s), Some(o)) => s.node_value_eq(o),
_ => false,
}
}
}
impl<T: NodeValue, E: 'static> NodeValue for Result<T, E> {
fn node_value_eq(&self, other: &Self) -> bool {
match (self, other) {
(Ok(s), Ok(o)) => s.node_value_eq(o),
_ => false,
}
}
}
pub(crate) struct ConstNode<V, S> { pub(crate) struct ConstNode<V, S> {
value: Rc<RefCell<Option<V>>>, value: Rc<RefCell<Option<V>>>,
synchronicity: std::marker::PhantomData<S>, synchronicity: std::marker::PhantomData<S>,
@ -256,9 +277,21 @@ impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S>
} }
} }
impl<R, V, S> std::fmt::Debug for RuleNode<R, V, S> { struct RuleLabel<'a, R: Rule>(&'a R);
impl<'a, R: Rule> std::fmt::Display for RuleLabel<'a, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "RuleNode<{}>", std::any::type_name::<R>()) self.0.node_label(f)
}
}
impl<R: Rule, V, S> std::fmt::Debug for RuleNode<R, V, S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"RuleNode<{}>({})",
std::any::type_name::<R>(),
RuleLabel(&self.rule)
)
} }
} }
@ -374,8 +407,20 @@ impl<R: AsyncRule> Node<R::Output, Asynchronous> for AsyncRuleNode<R, R::Output>
} }
} }
impl<R, V> std::fmt::Debug for AsyncRuleNode<R, V> { struct AsyncRuleLabel<'a, R: AsyncRule>(&'a R);
impl<'a, R: AsyncRule> std::fmt::Display for AsyncRuleLabel<'a, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "AsyncRuleNode<{}>", std::any::type_name::<R>()) self.0.node_label(f)
}
}
impl<R: AsyncRule, V> std::fmt::Debug for AsyncRuleNode<R, V> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"AsyncRuleNode<{}>({})",
std::any::type_name::<R>(),
AsyncRuleLabel(&self.rule)
)
} }
} }

View File

@ -38,6 +38,11 @@ pub trait Rule: InputVisitable + 'static {
/// 2. A rule will never be evaluated before _all_ of its dependencies up-to-date. That is, it will never /// 2. A rule will never be evaluated before _all_ of its dependencies up-to-date. That is, it will never
/// be evaluated with mix of valid and invalid dependencies. /// be evaluated with mix of valid and invalid dependencies.
fn evaluate(&mut self) -> Self::Output; fn evaluate(&mut self) -> Self::Output;
#[allow(unused_variables)]
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
Ok(())
}
} }
/// A rule produces a value for a graph node asynchronously. /// A rule produces a value for a graph node asynchronously.
@ -64,6 +69,11 @@ pub trait AsyncRule: InputVisitable + 'static {
/// ///
/// See [`Rule::evaluate`] for additional details; the same considerations apply. /// See [`Rule::evaluate`] for additional details; the same considerations apply.
fn evaluate(&mut self) -> impl Future<Output = Self::Output> + '_; fn evaluate(&mut self) -> impl Future<Output = Self::Output> + '_;
#[allow(unused_variables)]
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
Ok(())
}
} }
/// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited. /// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited.