Only update downstream nodes if an input changes

This commit is contained in:
Shadowfacts 2024-11-01 11:35:17 -04:00
parent de025dc138
commit 1d1673e5ee
4 changed files with 121 additions and 59 deletions

View File

@ -1,4 +1,4 @@
use crate::node::{AsyncRuleNode, ConstNode, Node, RuleNode};
use crate::node::{AsyncRuleNode, ConstNode, Node, NodeValue, RuleNode};
use crate::util;
use crate::{
AsyncRule, Asynchronous, ErasedNode, Graph, Input, InvalidationSignal, NodeGraph, NodeId, Rule,
@ -42,7 +42,7 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
self.output = Some(input);
}
fn add_node<V: 'static>(&mut self, node: impl Node<V, S> + 'static) -> Input<V> {
fn add_node<V: NodeValue>(&mut self, node: impl Node<V, S> + 'static) -> Input<V> {
let value = Rc::clone(node.value_rc());
let erased = ErasedNode::new(node);
let idx = self.node_graph.borrow_mut().add_node(erased);
@ -52,7 +52,7 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
}
}
pub fn add_value<V: 'static>(&mut self, value: V) -> Input<V> {
pub fn add_value<V: NodeValue>(&mut self, value: V) -> Input<V> {
return self.add_node(ConstNode::new(value));
}

View File

@ -4,7 +4,7 @@ mod synchronicity;
mod util;
use builder::{BuildGraphError, GraphBuilder};
use node::ErasedNode;
use node::{ErasedNode, NodeValue};
use petgraph::visit::{IntoEdgeReferences, NodeIndexable};
use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef};
use std::cell::{Cell, Ref, RefCell};
@ -124,17 +124,19 @@ impl<O: 'static> Graph<O, Synchronous> {
let node = &mut graph[idx];
if !node.is_valid() {
// Update this node
node.update();
let value_changed = node.update();
// Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph)
let dependents = graph
.edges_directed(idx, petgraph::Direction::Outgoing)
.map(|edge| edge.target())
// Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate
.collect::<Vec<_>>();
for dependent_idx in dependents {
let dependent = &mut graph[dependent_idx];
dependent.invalidate();
if value_changed {
// Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph)
let dependents = graph
.edges_directed(idx, petgraph::Direction::Outgoing)
.map(|edge| edge.target())
// Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate
.collect::<Vec<_>>();
for dependent_idx in dependents {
let dependent = &mut graph[dependent_idx];
dependent.invalidate();
}
}
}
}
@ -162,17 +164,19 @@ impl<O: 'static> Graph<O, Asynchronous> {
let node = &mut graph[idx];
if !node.is_valid() {
// Update this node
node.update().await;
let value_changed = node.update().await;
// Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph)
let dependents = graph
.edges_directed(idx, petgraph::Direction::Outgoing)
.map(|edge| edge.target())
// Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate
.collect::<Vec<_>>();
for dependent_idx in dependents {
let dependent = &mut graph[dependent_idx];
dependent.invalidate();
if value_changed {
// Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph)
let dependents = graph
.edges_directed(idx, petgraph::Direction::Outgoing)
.map(|edge| edge.target())
// Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate
.collect::<Vec<_>>();
for dependent_idx in dependents {
let dependent = &mut graph[dependent_idx];
dependent.invalidate();
}
}
}
}
@ -236,7 +240,7 @@ impl<S: Synchronicity> InvalidationSignal<S> {
// TODO: i really want Input to be able to implement Deref somehow
pub trait Rule: 'static {
type Output;
type Output: NodeValue;
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
@ -244,7 +248,7 @@ pub trait Rule: 'static {
}
pub trait AsyncRule: 'static {
type Output: 'static;
type Output: NodeValue;
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
@ -260,7 +264,7 @@ mod tests {
use super::*;
struct ConstantRule<T>(T);
impl<T: Clone + 'static> Rule for ConstantRule<T> {
impl<T: Clone + NodeValue> Rule for ConstantRule<T> {
type Output = T;
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
fn evaluate(&mut self) -> Self::Output {
@ -490,4 +494,33 @@ mod tests {
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), NonCloneable);
}
#[test]
fn only_update_downstream_nodes_if_value_changes() {
let mut builder = GraphBuilder::new();
let mut invalidate = None;
let a = builder.add_invalidatable_rule(|inv| {
invalidate = Some(inv);
ConstantRule(0)
});
struct IncAdd(Input<i32>, i32);
impl Rule for IncAdd {
type Output = i32;
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
}
fn evaluate(&mut self) -> Self::Output {
self.1 += 1;
*self.0.value() + self.1
}
}
builder.set_output(IncAdd(a, 0));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 1);
// IncAdd should not be evaluated again, despite its input being invalidated, so the output should be unchanged
invalidate.unwrap().invalidate();
assert!(!graph.is_output_valid());
assert_eq!(*graph.evaluate(), 1);
}
}

View File

@ -13,7 +13,7 @@ pub(crate) struct ErasedNode<Synch: Synchronicity> {
}
impl<S: Synchronicity> ErasedNode<S> {
pub(crate) fn new<N: Node<V, S> + 'static, V: 'static>(base: N) -> Self {
pub(crate) fn new<N: Node<V, S> + 'static, V: NodeValue>(base: N) -> Self {
// i don't love the double boxing, but i'm not sure how else to do this
let thing: Box<dyn Node<V, S>> = Box::new(base);
let any: Box<dyn Any> = Box::new(thing);
@ -47,18 +47,18 @@ impl<S: Synchronicity> ErasedNode<S> {
}
impl ErasedNode<Synchronous> {
pub(crate) fn update(&mut self) {
pub(crate) fn update(&mut self) -> bool {
(self.update)(&mut self.any)
}
}
impl ErasedNode<Asynchronous> {
pub(crate) async fn update(&mut self) {
pub(crate) async fn update(&mut self) -> bool {
(self.update)(&mut self.any).await
}
}
pub(crate) trait Node<Value: 'static, Synch: Synchronicity> {
pub(crate) trait Node<Value: NodeValue, Synch: Synchronicity> {
fn is_valid(&self) -> bool;
fn invalidate(&mut self);
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ());
@ -66,6 +66,16 @@ pub(crate) trait Node<Value: 'static, Synch: Synchronicity> {
fn value_rc(&self) -> &Rc<RefCell<Option<Value>>>;
}
pub trait NodeValue: 'static {
fn node_value_eq(&self, other: &Self) -> bool;
}
impl<T: PartialEq + 'static> NodeValue for T {
fn node_value_eq(&self, other: &Self) -> bool {
self == other
}
}
pub(crate) struct ConstNode<V, S> {
value: Rc<RefCell<Option<V>>>,
synchronicity: std::marker::PhantomData<S>,
@ -80,7 +90,7 @@ impl<V, S> ConstNode<V, S> {
}
}
impl<V: 'static, S: Synchronicity> Node<V, S> for ConstNode<V, S> {
impl<V: NodeValue, S: Synchronicity> Node<V, S> for ConstNode<V, S> {
fn is_valid(&self) -> bool {
true
}
@ -136,10 +146,19 @@ impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S>
}
fn update(&mut self) -> S::UpdateResult<'_> {
let new_value = self.rule.evaluate();
self.valid = true;
*self.value.borrow_mut() = Some(new_value);
S::make_update_result()
let new_value = self.rule.evaluate();
let mut value = self.value.borrow_mut();
let value_changed = value
.as_ref()
.map_or(true, |v| !v.node_value_eq(&new_value));
if value_changed {
*value = Some(new_value);
}
S::make_update_result(value_changed)
}
fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> {
@ -192,9 +211,19 @@ impl<R: AsyncRule> Node<R::Output, Asynchronous> for AsyncRuleNode<R, R::Output>
}
impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
async fn do_update(&mut self) {
let new_value = self.rule.evaluate().await;
async fn do_update(&mut self) -> bool {
self.valid = true;
*self.value.borrow_mut() = Some(new_value);
let new_value = self.rule.evaluate().await;
let mut value = self.value.borrow_mut();
let value_changed = value
.as_ref()
.map_or(true, |v| !v.node_value_eq(&new_value));
if value_changed {
*value = Some(new_value);
}
value_changed
}
}

View File

@ -1,53 +1,53 @@
use crate::node::Node;
use crate::node::{Node, NodeValue};
use std::any::Any;
use std::future::Future;
use std::pin::Pin;
pub trait Synchronicity: 'static {
type UpdateFn;
fn make_update_fn<V: 'static>() -> Self::UpdateFn;
fn make_update_fn<V: NodeValue>() -> Self::UpdateFn;
type UpdateResult<'a>;
fn make_update_result<'a>() -> Self::UpdateResult<'a>;
fn make_update_result<'a>(result: bool) -> Self::UpdateResult<'a>;
}
pub struct Synchronous;
impl Synchronicity for Synchronous {
type UpdateFn = Box<dyn Fn(&mut Box<dyn Any>) -> ()>;
type UpdateFn = Box<dyn Fn(&mut Box<dyn Any>) -> bool>;
fn make_update_fn<V: 'static>() -> Self::UpdateFn {
fn make_update_fn<V: NodeValue>() -> Self::UpdateFn {
Box::new(|any| {
let x = any.downcast_mut::<Box<dyn Node<V, Self>>>().unwrap();
x.update();
x.update()
})
}
type UpdateResult<'a> = ();
type UpdateResult<'a> = bool;
fn make_update_result<'a>() -> Self::UpdateResult<'a> {}
fn make_update_result<'a>(result: bool) -> Self::UpdateResult<'a> {
result
}
}
pub struct Asynchronous;
impl Synchronicity for Asynchronous {
type UpdateFn =
Box<dyn for<'a> Fn(&'a mut Box<dyn Any>) -> Pin<Box<dyn Future<Output = ()> + 'a>>>;
Box<dyn for<'a> Fn(&'a mut Box<dyn Any>) -> Pin<Box<dyn Future<Output = bool> + 'a>>>;
fn make_update_fn<V: 'static>() -> Self::UpdateFn {
Box::new(|any| Box::pin(Asynchronous::do_async_update::<V>(any)))
fn make_update_fn<V: NodeValue>() -> Self::UpdateFn {
Box::new(|any| {
Box::pin({
let x = any.downcast_mut::<Box<dyn Node<V, Self>>>().unwrap();
x.update()
})
})
}
type UpdateResult<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;
type UpdateResult<'a> = Pin<Box<dyn Future<Output = bool> + 'a>>;
fn make_update_result<'a>() -> Self::UpdateResult<'a> {
Box::pin(std::future::ready(()))
}
}
impl Asynchronous {
async fn do_async_update<V: 'static>(any: &mut Box<dyn Any>) {
let x = any.downcast_mut::<Box<dyn Node<V, Self>>>().unwrap();
x.update().await;
fn make_update_result<'a>(result: bool) -> Self::UpdateResult<'a> {
Box::pin(std::future::ready(result))
}
}