Only update downstream nodes if an input changes

This commit is contained in:
Shadowfacts 2024-11-01 11:35:17 -04:00
parent de025dc138
commit 1d1673e5ee
4 changed files with 121 additions and 59 deletions

View File

@ -1,4 +1,4 @@
use crate::node::{AsyncRuleNode, ConstNode, Node, RuleNode}; use crate::node::{AsyncRuleNode, ConstNode, Node, NodeValue, RuleNode};
use crate::util; use crate::util;
use crate::{ use crate::{
AsyncRule, Asynchronous, ErasedNode, Graph, Input, InvalidationSignal, NodeGraph, NodeId, Rule, AsyncRule, Asynchronous, ErasedNode, Graph, Input, InvalidationSignal, NodeGraph, NodeId, Rule,
@ -42,7 +42,7 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
self.output = Some(input); self.output = Some(input);
} }
fn add_node<V: 'static>(&mut self, node: impl Node<V, S> + 'static) -> Input<V> { fn add_node<V: NodeValue>(&mut self, node: impl Node<V, S> + 'static) -> Input<V> {
let value = Rc::clone(node.value_rc()); let value = Rc::clone(node.value_rc());
let erased = ErasedNode::new(node); let erased = ErasedNode::new(node);
let idx = self.node_graph.borrow_mut().add_node(erased); let idx = self.node_graph.borrow_mut().add_node(erased);
@ -52,7 +52,7 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
} }
} }
pub fn add_value<V: 'static>(&mut self, value: V) -> Input<V> { pub fn add_value<V: NodeValue>(&mut self, value: V) -> Input<V> {
return self.add_node(ConstNode::new(value)); return self.add_node(ConstNode::new(value));
} }

View File

@ -4,7 +4,7 @@ mod synchronicity;
mod util; mod util;
use builder::{BuildGraphError, GraphBuilder}; use builder::{BuildGraphError, GraphBuilder};
use node::ErasedNode; use node::{ErasedNode, NodeValue};
use petgraph::visit::{IntoEdgeReferences, NodeIndexable}; use petgraph::visit::{IntoEdgeReferences, NodeIndexable};
use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef}; use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef};
use std::cell::{Cell, Ref, RefCell}; use std::cell::{Cell, Ref, RefCell};
@ -124,17 +124,19 @@ impl<O: 'static> Graph<O, Synchronous> {
let node = &mut graph[idx]; let node = &mut graph[idx];
if !node.is_valid() { if !node.is_valid() {
// Update this node // Update this node
node.update(); let value_changed = node.update();
// Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph) if value_changed {
let dependents = graph // Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph)
.edges_directed(idx, petgraph::Direction::Outgoing) let dependents = graph
.map(|edge| edge.target()) .edges_directed(idx, petgraph::Direction::Outgoing)
// Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate .map(|edge| edge.target())
.collect::<Vec<_>>(); // Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate
for dependent_idx in dependents { .collect::<Vec<_>>();
let dependent = &mut graph[dependent_idx]; for dependent_idx in dependents {
dependent.invalidate(); let dependent = &mut graph[dependent_idx];
dependent.invalidate();
}
} }
} }
} }
@ -162,17 +164,19 @@ impl<O: 'static> Graph<O, Asynchronous> {
let node = &mut graph[idx]; let node = &mut graph[idx];
if !node.is_valid() { if !node.is_valid() {
// Update this node // Update this node
node.update().await; let value_changed = node.update().await;
// Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph) if value_changed {
let dependents = graph // Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph)
.edges_directed(idx, petgraph::Direction::Outgoing) let dependents = graph
.map(|edge| edge.target()) .edges_directed(idx, petgraph::Direction::Outgoing)
// Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate .map(|edge| edge.target())
.collect::<Vec<_>>(); // Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate
for dependent_idx in dependents { .collect::<Vec<_>>();
let dependent = &mut graph[dependent_idx]; for dependent_idx in dependents {
dependent.invalidate(); let dependent = &mut graph[dependent_idx];
dependent.invalidate();
}
} }
} }
} }
@ -236,7 +240,7 @@ impl<S: Synchronicity> InvalidationSignal<S> {
// TODO: i really want Input to be able to implement Deref somehow // TODO: i really want Input to be able to implement Deref somehow
pub trait Rule: 'static { pub trait Rule: 'static {
type Output; type Output: NodeValue;
fn visit_inputs(&self, visitor: &mut impl InputVisitor); fn visit_inputs(&self, visitor: &mut impl InputVisitor);
@ -244,7 +248,7 @@ pub trait Rule: 'static {
} }
pub trait AsyncRule: 'static { pub trait AsyncRule: 'static {
type Output: 'static; type Output: NodeValue;
fn visit_inputs(&self, visitor: &mut impl InputVisitor); fn visit_inputs(&self, visitor: &mut impl InputVisitor);
@ -260,7 +264,7 @@ mod tests {
use super::*; use super::*;
struct ConstantRule<T>(T); struct ConstantRule<T>(T);
impl<T: Clone + 'static> Rule for ConstantRule<T> { impl<T: Clone + NodeValue> Rule for ConstantRule<T> {
type Output = T; type Output = T;
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
fn evaluate(&mut self) -> Self::Output { fn evaluate(&mut self) -> Self::Output {
@ -490,4 +494,33 @@ mod tests {
let mut graph = builder.build().unwrap(); let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), NonCloneable); assert_eq!(*graph.evaluate(), NonCloneable);
} }
#[test]
fn only_update_downstream_nodes_if_value_changes() {
let mut builder = GraphBuilder::new();
let mut invalidate = None;
let a = builder.add_invalidatable_rule(|inv| {
invalidate = Some(inv);
ConstantRule(0)
});
struct IncAdd(Input<i32>, i32);
impl Rule for IncAdd {
type Output = i32;
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
}
fn evaluate(&mut self) -> Self::Output {
self.1 += 1;
*self.0.value() + self.1
}
}
builder.set_output(IncAdd(a, 0));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 1);
// IncAdd should not be evaluated again, despite its input being invalidated, so the output should be unchanged
invalidate.unwrap().invalidate();
assert!(!graph.is_output_valid());
assert_eq!(*graph.evaluate(), 1);
}
} }

View File

@ -13,7 +13,7 @@ pub(crate) struct ErasedNode<Synch: Synchronicity> {
} }
impl<S: Synchronicity> ErasedNode<S> { impl<S: Synchronicity> ErasedNode<S> {
pub(crate) fn new<N: Node<V, S> + 'static, V: 'static>(base: N) -> Self { pub(crate) fn new<N: Node<V, S> + 'static, V: NodeValue>(base: N) -> Self {
// i don't love the double boxing, but i'm not sure how else to do this // i don't love the double boxing, but i'm not sure how else to do this
let thing: Box<dyn Node<V, S>> = Box::new(base); let thing: Box<dyn Node<V, S>> = Box::new(base);
let any: Box<dyn Any> = Box::new(thing); let any: Box<dyn Any> = Box::new(thing);
@ -47,18 +47,18 @@ impl<S: Synchronicity> ErasedNode<S> {
} }
impl ErasedNode<Synchronous> { impl ErasedNode<Synchronous> {
pub(crate) fn update(&mut self) { pub(crate) fn update(&mut self) -> bool {
(self.update)(&mut self.any) (self.update)(&mut self.any)
} }
} }
impl ErasedNode<Asynchronous> { impl ErasedNode<Asynchronous> {
pub(crate) async fn update(&mut self) { pub(crate) async fn update(&mut self) -> bool {
(self.update)(&mut self.any).await (self.update)(&mut self.any).await
} }
} }
pub(crate) trait Node<Value: 'static, Synch: Synchronicity> { pub(crate) trait Node<Value: NodeValue, Synch: Synchronicity> {
fn is_valid(&self) -> bool; fn is_valid(&self) -> bool;
fn invalidate(&mut self); fn invalidate(&mut self);
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()); fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ());
@ -66,6 +66,16 @@ pub(crate) trait Node<Value: 'static, Synch: Synchronicity> {
fn value_rc(&self) -> &Rc<RefCell<Option<Value>>>; fn value_rc(&self) -> &Rc<RefCell<Option<Value>>>;
} }
pub trait NodeValue: 'static {
fn node_value_eq(&self, other: &Self) -> bool;
}
impl<T: PartialEq + 'static> NodeValue for T {
fn node_value_eq(&self, other: &Self) -> bool {
self == other
}
}
pub(crate) struct ConstNode<V, S> { pub(crate) struct ConstNode<V, S> {
value: Rc<RefCell<Option<V>>>, value: Rc<RefCell<Option<V>>>,
synchronicity: std::marker::PhantomData<S>, synchronicity: std::marker::PhantomData<S>,
@ -80,7 +90,7 @@ impl<V, S> ConstNode<V, S> {
} }
} }
impl<V: 'static, S: Synchronicity> Node<V, S> for ConstNode<V, S> { impl<V: NodeValue, S: Synchronicity> Node<V, S> for ConstNode<V, S> {
fn is_valid(&self) -> bool { fn is_valid(&self) -> bool {
true true
} }
@ -136,10 +146,19 @@ impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S>
} }
fn update(&mut self) -> S::UpdateResult<'_> { fn update(&mut self) -> S::UpdateResult<'_> {
let new_value = self.rule.evaluate();
self.valid = true; self.valid = true;
*self.value.borrow_mut() = Some(new_value);
S::make_update_result() let new_value = self.rule.evaluate();
let mut value = self.value.borrow_mut();
let value_changed = value
.as_ref()
.map_or(true, |v| !v.node_value_eq(&new_value));
if value_changed {
*value = Some(new_value);
}
S::make_update_result(value_changed)
} }
fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> { fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> {
@ -192,9 +211,19 @@ impl<R: AsyncRule> Node<R::Output, Asynchronous> for AsyncRuleNode<R, R::Output>
} }
impl<R: AsyncRule> AsyncRuleNode<R, R::Output> { impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
async fn do_update(&mut self) { async fn do_update(&mut self) -> bool {
let new_value = self.rule.evaluate().await;
self.valid = true; self.valid = true;
*self.value.borrow_mut() = Some(new_value);
let new_value = self.rule.evaluate().await;
let mut value = self.value.borrow_mut();
let value_changed = value
.as_ref()
.map_or(true, |v| !v.node_value_eq(&new_value));
if value_changed {
*value = Some(new_value);
}
value_changed
} }
} }

View File

@ -1,53 +1,53 @@
use crate::node::Node; use crate::node::{Node, NodeValue};
use std::any::Any; use std::any::Any;
use std::future::Future; use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
pub trait Synchronicity: 'static { pub trait Synchronicity: 'static {
type UpdateFn; type UpdateFn;
fn make_update_fn<V: 'static>() -> Self::UpdateFn; fn make_update_fn<V: NodeValue>() -> Self::UpdateFn;
type UpdateResult<'a>; type UpdateResult<'a>;
fn make_update_result<'a>() -> Self::UpdateResult<'a>; fn make_update_result<'a>(result: bool) -> Self::UpdateResult<'a>;
} }
pub struct Synchronous; pub struct Synchronous;
impl Synchronicity for Synchronous { impl Synchronicity for Synchronous {
type UpdateFn = Box<dyn Fn(&mut Box<dyn Any>) -> ()>; type UpdateFn = Box<dyn Fn(&mut Box<dyn Any>) -> bool>;
fn make_update_fn<V: 'static>() -> Self::UpdateFn { fn make_update_fn<V: NodeValue>() -> Self::UpdateFn {
Box::new(|any| { Box::new(|any| {
let x = any.downcast_mut::<Box<dyn Node<V, Self>>>().unwrap(); let x = any.downcast_mut::<Box<dyn Node<V, Self>>>().unwrap();
x.update(); x.update()
}) })
} }
type UpdateResult<'a> = (); type UpdateResult<'a> = bool;
fn make_update_result<'a>() -> Self::UpdateResult<'a> {} fn make_update_result<'a>(result: bool) -> Self::UpdateResult<'a> {
result
}
} }
pub struct Asynchronous; pub struct Asynchronous;
impl Synchronicity for Asynchronous { impl Synchronicity for Asynchronous {
type UpdateFn = type UpdateFn =
Box<dyn for<'a> Fn(&'a mut Box<dyn Any>) -> Pin<Box<dyn Future<Output = ()> + 'a>>>; Box<dyn for<'a> Fn(&'a mut Box<dyn Any>) -> Pin<Box<dyn Future<Output = bool> + 'a>>>;
fn make_update_fn<V: 'static>() -> Self::UpdateFn { fn make_update_fn<V: NodeValue>() -> Self::UpdateFn {
Box::new(|any| Box::pin(Asynchronous::do_async_update::<V>(any))) Box::new(|any| {
Box::pin({
let x = any.downcast_mut::<Box<dyn Node<V, Self>>>().unwrap();
x.update()
})
})
} }
type UpdateResult<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>; type UpdateResult<'a> = Pin<Box<dyn Future<Output = bool> + 'a>>;
fn make_update_result<'a>() -> Self::UpdateResult<'a> { fn make_update_result<'a>(result: bool) -> Self::UpdateResult<'a> {
Box::pin(std::future::ready(())) Box::pin(std::future::ready(result))
}
}
impl Asynchronous {
async fn do_async_update<V: 'static>(any: &mut Box<dyn Any>) {
let x = any.downcast_mut::<Box<dyn Node<V, Self>>>().unwrap();
x.update().await;
} }
} }