From 1d1673e5ee86651f3f8942c9092525b4ab79de27 Mon Sep 17 00:00:00 2001 From: Shadowfacts Date: Fri, 1 Nov 2024 11:35:17 -0400 Subject: [PATCH] Only update downstream nodes if an input changes --- crates/graph/src/builder.rs | 6 +-- crates/graph/src/lib.rs | 81 ++++++++++++++++++++++--------- crates/graph/src/node.rs | 51 ++++++++++++++----- crates/graph/src/synchronicity.rs | 42 ++++++++-------- 4 files changed, 121 insertions(+), 59 deletions(-) diff --git a/crates/graph/src/builder.rs b/crates/graph/src/builder.rs index 157b80d..eb3b394 100644 --- a/crates/graph/src/builder.rs +++ b/crates/graph/src/builder.rs @@ -1,4 +1,4 @@ -use crate::node::{AsyncRuleNode, ConstNode, Node, RuleNode}; +use crate::node::{AsyncRuleNode, ConstNode, Node, NodeValue, RuleNode}; use crate::util; use crate::{ AsyncRule, Asynchronous, ErasedNode, Graph, Input, InvalidationSignal, NodeGraph, NodeId, Rule, @@ -42,7 +42,7 @@ impl GraphBuilder { self.output = Some(input); } - fn add_node(&mut self, node: impl Node + 'static) -> Input { + fn add_node(&mut self, node: impl Node + 'static) -> Input { let value = Rc::clone(node.value_rc()); let erased = ErasedNode::new(node); let idx = self.node_graph.borrow_mut().add_node(erased); @@ -52,7 +52,7 @@ impl GraphBuilder { } } - pub fn add_value(&mut self, value: V) -> Input { + pub fn add_value(&mut self, value: V) -> Input { return self.add_node(ConstNode::new(value)); } diff --git a/crates/graph/src/lib.rs b/crates/graph/src/lib.rs index 26c68da..c069870 100644 --- a/crates/graph/src/lib.rs +++ b/crates/graph/src/lib.rs @@ -4,7 +4,7 @@ mod synchronicity; mod util; use builder::{BuildGraphError, GraphBuilder}; -use node::ErasedNode; +use node::{ErasedNode, NodeValue}; use petgraph::visit::{IntoEdgeReferences, NodeIndexable}; use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef}; use std::cell::{Cell, Ref, RefCell}; @@ -124,17 +124,19 @@ impl Graph { let node = &mut graph[idx]; if !node.is_valid() { // Update this node - node.update(); + let value_changed = node.update(); - // Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph) - let dependents = graph - .edges_directed(idx, petgraph::Direction::Outgoing) - .map(|edge| edge.target()) - // Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate - .collect::>(); - for dependent_idx in dependents { - let dependent = &mut graph[dependent_idx]; - dependent.invalidate(); + if value_changed { + // Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph) + let dependents = graph + .edges_directed(idx, petgraph::Direction::Outgoing) + .map(|edge| edge.target()) + // Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate + .collect::>(); + for dependent_idx in dependents { + let dependent = &mut graph[dependent_idx]; + dependent.invalidate(); + } } } } @@ -162,17 +164,19 @@ impl Graph { let node = &mut graph[idx]; if !node.is_valid() { // Update this node - node.update().await; + let value_changed = node.update().await; - // Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph) - let dependents = graph - .edges_directed(idx, petgraph::Direction::Outgoing) - .map(|edge| edge.target()) - // Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate - .collect::>(); - for dependent_idx in dependents { - let dependent = &mut graph[dependent_idx]; - dependent.invalidate(); + if value_changed { + // Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph) + let dependents = graph + .edges_directed(idx, petgraph::Direction::Outgoing) + .map(|edge| edge.target()) + // Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate + .collect::>(); + for dependent_idx in dependents { + let dependent = &mut graph[dependent_idx]; + dependent.invalidate(); + } } } } @@ -236,7 +240,7 @@ impl InvalidationSignal { // TODO: i really want Input to be able to implement Deref somehow pub trait Rule: 'static { - type Output; + type Output: NodeValue; fn visit_inputs(&self, visitor: &mut impl InputVisitor); @@ -244,7 +248,7 @@ pub trait Rule: 'static { } pub trait AsyncRule: 'static { - type Output: 'static; + type Output: NodeValue; fn visit_inputs(&self, visitor: &mut impl InputVisitor); @@ -260,7 +264,7 @@ mod tests { use super::*; struct ConstantRule(T); - impl Rule for ConstantRule { + impl Rule for ConstantRule { type Output = T; fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} fn evaluate(&mut self) -> Self::Output { @@ -490,4 +494,33 @@ mod tests { let mut graph = builder.build().unwrap(); assert_eq!(*graph.evaluate(), NonCloneable); } + + #[test] + fn only_update_downstream_nodes_if_value_changes() { + let mut builder = GraphBuilder::new(); + let mut invalidate = None; + let a = builder.add_invalidatable_rule(|inv| { + invalidate = Some(inv); + ConstantRule(0) + }); + struct IncAdd(Input, i32); + impl Rule for IncAdd { + type Output = i32; + fn visit_inputs(&self, visitor: &mut impl InputVisitor) { + visitor.visit(&self.0); + } + fn evaluate(&mut self) -> Self::Output { + self.1 += 1; + *self.0.value() + self.1 + } + } + builder.set_output(IncAdd(a, 0)); + let mut graph = builder.build().unwrap(); + assert_eq!(*graph.evaluate(), 1); + + // IncAdd should not be evaluated again, despite its input being invalidated, so the output should be unchanged + invalidate.unwrap().invalidate(); + assert!(!graph.is_output_valid()); + assert_eq!(*graph.evaluate(), 1); + } } diff --git a/crates/graph/src/node.rs b/crates/graph/src/node.rs index a4e3da5..15fafe5 100644 --- a/crates/graph/src/node.rs +++ b/crates/graph/src/node.rs @@ -13,7 +13,7 @@ pub(crate) struct ErasedNode { } impl ErasedNode { - pub(crate) fn new + 'static, V: 'static>(base: N) -> Self { + pub(crate) fn new + 'static, V: NodeValue>(base: N) -> Self { // i don't love the double boxing, but i'm not sure how else to do this let thing: Box> = Box::new(base); let any: Box = Box::new(thing); @@ -47,18 +47,18 @@ impl ErasedNode { } impl ErasedNode { - pub(crate) fn update(&mut self) { + pub(crate) fn update(&mut self) -> bool { (self.update)(&mut self.any) } } impl ErasedNode { - pub(crate) async fn update(&mut self) { + pub(crate) async fn update(&mut self) -> bool { (self.update)(&mut self.any).await } } -pub(crate) trait Node { +pub(crate) trait Node { fn is_valid(&self) -> bool; fn invalidate(&mut self); fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()); @@ -66,6 +66,16 @@ pub(crate) trait Node { fn value_rc(&self) -> &Rc>>; } +pub trait NodeValue: 'static { + fn node_value_eq(&self, other: &Self) -> bool; +} + +impl NodeValue for T { + fn node_value_eq(&self, other: &Self) -> bool { + self == other + } +} + pub(crate) struct ConstNode { value: Rc>>, synchronicity: std::marker::PhantomData, @@ -80,7 +90,7 @@ impl ConstNode { } } -impl Node for ConstNode { +impl Node for ConstNode { fn is_valid(&self) -> bool { true } @@ -136,10 +146,19 @@ impl Node for RuleNode } fn update(&mut self) -> S::UpdateResult<'_> { - let new_value = self.rule.evaluate(); self.valid = true; - *self.value.borrow_mut() = Some(new_value); - S::make_update_result() + + let new_value = self.rule.evaluate(); + let mut value = self.value.borrow_mut(); + let value_changed = value + .as_ref() + .map_or(true, |v| !v.node_value_eq(&new_value)); + + if value_changed { + *value = Some(new_value); + } + + S::make_update_result(value_changed) } fn value_rc(&self) -> &Rc>> { @@ -192,9 +211,19 @@ impl Node for AsyncRuleNode } impl AsyncRuleNode { - async fn do_update(&mut self) { - let new_value = self.rule.evaluate().await; + async fn do_update(&mut self) -> bool { self.valid = true; - *self.value.borrow_mut() = Some(new_value); + + let new_value = self.rule.evaluate().await; + let mut value = self.value.borrow_mut(); + let value_changed = value + .as_ref() + .map_or(true, |v| !v.node_value_eq(&new_value)); + + if value_changed { + *value = Some(new_value); + } + + value_changed } } diff --git a/crates/graph/src/synchronicity.rs b/crates/graph/src/synchronicity.rs index f5e95e9..4cf7bfa 100644 --- a/crates/graph/src/synchronicity.rs +++ b/crates/graph/src/synchronicity.rs @@ -1,53 +1,53 @@ -use crate::node::Node; +use crate::node::{Node, NodeValue}; use std::any::Any; use std::future::Future; use std::pin::Pin; pub trait Synchronicity: 'static { type UpdateFn; - fn make_update_fn() -> Self::UpdateFn; + fn make_update_fn() -> Self::UpdateFn; type UpdateResult<'a>; - fn make_update_result<'a>() -> Self::UpdateResult<'a>; + fn make_update_result<'a>(result: bool) -> Self::UpdateResult<'a>; } pub struct Synchronous; impl Synchronicity for Synchronous { - type UpdateFn = Box) -> ()>; + type UpdateFn = Box) -> bool>; - fn make_update_fn() -> Self::UpdateFn { + fn make_update_fn() -> Self::UpdateFn { Box::new(|any| { let x = any.downcast_mut::>>().unwrap(); - x.update(); + x.update() }) } - type UpdateResult<'a> = (); + type UpdateResult<'a> = bool; - fn make_update_result<'a>() -> Self::UpdateResult<'a> {} + fn make_update_result<'a>(result: bool) -> Self::UpdateResult<'a> { + result + } } pub struct Asynchronous; impl Synchronicity for Asynchronous { type UpdateFn = - Box Fn(&'a mut Box) -> Pin + 'a>>>; + Box Fn(&'a mut Box) -> Pin + 'a>>>; - fn make_update_fn() -> Self::UpdateFn { - Box::new(|any| Box::pin(Asynchronous::do_async_update::(any))) + fn make_update_fn() -> Self::UpdateFn { + Box::new(|any| { + Box::pin({ + let x = any.downcast_mut::>>().unwrap(); + x.update() + }) + }) } - type UpdateResult<'a> = Pin + 'a>>; + type UpdateResult<'a> = Pin + 'a>>; - fn make_update_result<'a>() -> Self::UpdateResult<'a> { - Box::pin(std::future::ready(())) - } -} - -impl Asynchronous { - async fn do_async_update(any: &mut Box) { - let x = any.downcast_mut::>>().unwrap(); - x.update().await; + fn make_update_result<'a>(result: bool) -> Self::UpdateResult<'a> { + Box::pin(std::future::ready(result)) } }