From 640c0ab6202c1ae0f20d8c0e68dc893800c8b01d Mon Sep 17 00:00:00 2001 From: Shadowfacts Date: Tue, 31 Dec 2024 18:49:00 -0500 Subject: [PATCH] Allow dynamic nodes to add invalidatable rules --- crates/compute_graph/src/builder.rs | 10 +- crates/compute_graph/src/lib.rs | 144 ++++++++++++++++++++-------- crates/compute_graph/src/node.rs | 28 +++++- crates/compute_graph/src/rule.rs | 23 +++-- src/generator/posts.rs | 14 +-- src/generator/tags.rs | 11 ++- 6 files changed, 157 insertions(+), 73 deletions(-) diff --git a/crates/compute_graph/src/builder.rs b/crates/compute_graph/src/builder.rs index fe61194..cb29d33 100644 --- a/crates/compute_graph/src/builder.rs +++ b/crates/compute_graph/src/builder.rs @@ -179,17 +179,9 @@ impl GraphBuilder { } fn make_invalidation_signal(&self, input: &Input) -> InvalidationSignal { - let node_idx = Rc::clone(&input.node_idx); let graph = Rc::clone(&self.node_graph); let graph_is_valid = Rc::clone(&self.is_valid); - InvalidationSignal { - do_invalidate: Rc::new(Box::new(move || { - graph_is_valid.set(false); - let mut graph = graph.borrow_mut(); - let node = &mut graph[node_idx.get().unwrap()]; - node.invalidate(); - })), - } + InvalidationSignal::new(input, graph, graph_is_valid) } /// Adds a node to the graph whose output is additional nodes produced by the given rule. diff --git a/crates/compute_graph/src/lib.rs b/crates/compute_graph/src/lib.rs index 8b8176c..e9dffdf 100644 --- a/crates/compute_graph/src/lib.rs +++ b/crates/compute_graph/src/lib.rs @@ -336,7 +336,7 @@ impl Graph { let node = &mut graph[idx]; if !node.is_valid() { // Update this node - let mut ctx = NodeUpdateContext::new(); + let mut ctx = NodeUpdateContext::new(self); node.update(&mut ctx); drop(graph); @@ -389,7 +389,7 @@ impl Graph { let node = &mut graph[idx]; if !node.is_valid() { // Update this node - let mut ctx = NodeUpdateContext::new(); + let mut ctx = NodeUpdateContext::new(self); node.update(&mut ctx).await; drop(graph); @@ -444,6 +444,22 @@ pub struct InvalidationSignal { } impl InvalidationSignal { + pub(crate) fn new( + input: &Input, + graph: Rc>>, + graph_is_valid: Rc>, + ) -> Self { + let node_idx = Rc::clone(&input.node_idx); + InvalidationSignal { + do_invalidate: Rc::new(Box::new(move || { + graph_is_valid.set(false); + let mut graph = graph.borrow_mut(); + let node = &mut graph[node_idx.get().unwrap()]; + node.invalidate(); + })), + } + } + /// Tell the graph that the node corresponding to this signal is now invalid. /// /// Note: Calling this method does not trigger a graph evaluation, it merely marks the corresponding @@ -724,23 +740,24 @@ mod tests { assert_eq!(*graph.evaluate(), NonCloneable); } + struct IncAdd(Input, i32); + impl InputVisitable for IncAdd { + fn visit_inputs(&self, visitor: &mut impl InputVisitor) { + visitor.visit(&self.0); + } + } + impl Rule for IncAdd { + type Output = i32; + fn evaluate(&mut self) -> Self::Output { + self.1 += 1; + *self.0.value() + self.1 + } + } + #[test] fn only_update_downstream_nodes_if_value_changes() { let mut builder = GraphBuilder::new(); let (a, invalidate) = builder.add_invalidatable_rule(ConstantRule::new(0)); - struct IncAdd(Input, i32); - impl InputVisitable for IncAdd { - fn visit_inputs(&self, visitor: &mut impl InputVisitor) { - visitor.visit(&self.0); - } - } - impl Rule for IncAdd { - type Output = i32; - fn evaluate(&mut self) -> Self::Output { - self.1 += 1; - *self.0.value() + self.1 - } - } builder.set_output(IncAdd(a, 0)); let mut graph = builder.build().unwrap(); assert_eq!(*graph.evaluate(), 1); @@ -811,6 +828,19 @@ mod tests { ) } + struct DynamicSum(DynamicInput); + impl InputVisitable for DynamicSum { + fn visit_inputs(&self, visitor: &mut impl InputVisitor) { + visitor.visit_dynamic(&self.0); + } + } + impl Rule for DynamicSum { + type Output = i32; + fn evaluate(&mut self) -> Self::Output { + self.0.value().inputs.iter().map(|i| *i.value()).sum() + } + } + #[test] fn dynamic_rule() { let mut builder = GraphBuilder::new(); @@ -832,7 +862,8 @@ mod tests { ) -> Vec> { let count = *self.count.value(); for i in 1..=count { - self.node_factory.add_rule(ctx, i, || ConstantRule::new(i)); + self.node_factory + .add_rule(ctx, i, |ctx| ctx.add_rule(ConstantRule::new(i))); } self.node_factory.all_nodes(ctx) } @@ -841,19 +872,7 @@ mod tests { count, node_factory: DynamicNodeFactory::new(), }); - struct Sum(DynamicInput); - impl InputVisitable for Sum { - fn visit_inputs(&self, visitor: &mut impl InputVisitor) { - visitor.visit_dynamic(&self.0); - } - } - impl Rule for Sum { - type Output = i32; - fn evaluate(&mut self) -> Self::Output { - self.0.value().inputs.iter().map(|i| *i.value()).sum() - } - } - builder.set_output(Sum(all_inputs)); + builder.set_output(DynamicSum(all_inputs)); let mut graph = builder.build().unwrap(); assert_eq!(*graph.evaluate(), 1); set_count.set_value(2); @@ -891,19 +910,7 @@ mod tests { } } let all_inputs = builder.add_async_dynamic_rule(CountUpTo(count, vec![])); - struct Sum(DynamicInput); - impl InputVisitable for Sum { - fn visit_inputs(&self, visitor: &mut impl InputVisitor) { - visitor.visit_dynamic(&self.0); - } - } - impl Rule for Sum { - type Output = i32; - fn evaluate(&mut self) -> Self::Output { - self.0.value().inputs.iter().map(|i| *i.value()).sum() - } - } - builder.set_output(Sum(all_inputs)); + builder.set_output(DynamicSum(all_inputs)); let mut graph = builder.build().unwrap(); assert_eq!(*graph.evaluate_async().await, 1); set_count.set_value(2); @@ -912,4 +919,57 @@ mod tests { assert_eq!(*graph.evaluate_async().await, 10); println!("{}", graph.as_dot_string()); } + + #[test] + fn dynamic_invalidatable_rule() { + let mut builder = GraphBuilder::new(); + let (count, set_count) = builder.add_invalidatable_value(1); + struct CountUpTo { + count: Input, + signals: Rc>>, + node_factory: DynamicNodeFactory, + } + impl InputVisitable for CountUpTo { + fn visit_inputs(&self, visitor: &mut impl InputVisitor) { + visitor.visit(&self.count); + } + } + impl DynamicRule for CountUpTo { + type ChildOutput = i32; + fn evaluate( + &mut self, + ctx: &mut impl rule::DynamicRuleContext, + ) -> Vec> { + let count = *self.count.value(); + for i in 1..=count { + self.node_factory.add_rule(ctx, i, |ctx| { + let constant = ctx.add_rule(ConstantRule::new(i)); + let (input, signal) = ctx.add_invalidatable_rule(IncAdd(constant, 0)); + self.signals.borrow_mut().push(signal); + input + }); + } + self.node_factory.all_nodes(ctx) + } + } + let signals = Rc::new(RefCell::new(vec![])); + let all_inputs = builder.add_dynamic_rule(CountUpTo { + count, + signals: Rc::clone(&signals), + node_factory: DynamicNodeFactory::new(), + }); + builder.set_output(DynamicSum(all_inputs)); + let mut graph = builder.build().unwrap(); + assert_eq!(*graph.evaluate(), 2); + for signal in signals.borrow().iter() { + signal.invalidate(); + } + assert_eq!(*graph.evaluate(), 3); + set_count.set_value(2); + assert_eq!(*graph.evaluate(), 6); // new const node has value 2, IncAdd initially adds 1 + for signal in signals.borrow().iter() { + signal.invalidate(); + } + assert_eq!(*graph.evaluate(), 8); + } } diff --git a/crates/compute_graph/src/node.rs b/crates/compute_graph/src/node.rs index 8cf6cb4..ca1987f 100644 --- a/crates/compute_graph/src/node.rs +++ b/crates/compute_graph/src/node.rs @@ -3,7 +3,7 @@ use crate::rule::{ DynamicRuleContext, InputVisitable, Rule, }; use crate::synchronicity::{Asynchronous, Synchronicity}; -use crate::{Input, InputVisitor, NodeId, Synchronous}; +use crate::{Graph, Input, InputVisitor, InvalidationSignal, NodeGraph, NodeId, Synchronous}; use quote::ToTokens; use std::any::Any; use std::cell::{Cell, RefCell}; @@ -25,14 +25,18 @@ pub(crate) struct ErasedNode { } pub(crate) struct NodeUpdateContext { + pub(crate) graph: Rc>>, + pub(crate) graph_is_valid: Rc>, pub(crate) invalidate_dependent_nodes: bool, pub(crate) removed_nodes: Vec, pub(crate) added_nodes: Vec<(ErasedNode, Rc>>)>, } impl NodeUpdateContext { - pub(crate) fn new() -> Self { + pub(crate) fn new(graph: &Graph) -> Self { Self { + graph: Rc::clone(&graph.node_graph), + graph_is_valid: Rc::clone(&graph.is_valid), invalidate_dependent_nodes: false, removed_nodes: vec![], added_nodes: vec![], @@ -562,6 +566,19 @@ impl<'a, S: Synchronicity> DynamicRuleContext for DynamicRuleUpdateContext<'a, S { self.add_node(RuleNode::new(rule)) } + + fn add_invalidatable_rule(&mut self, rule: R) -> (Input, InvalidationSignal) + where + R: Rule, + { + let input = self.add_rule(rule); + let signal = InvalidationSignal::new( + &input, + Rc::clone(&self.0.graph), + Rc::clone(&self.0.graph_is_valid), + ); + (input, signal) + } } struct DynamicRuleLabel<'a, R: DynamicRule>(&'a R); @@ -658,6 +675,13 @@ impl<'a> DynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> { { DynamicRuleUpdateContext(self.0).add_rule(rule) } + + fn add_invalidatable_rule(&mut self, rule: R) -> (Input, InvalidationSignal) + where + R: Rule, + { + DynamicRuleUpdateContext(self.0).add_invalidatable_rule(rule) + } } impl<'a> AsyncDynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> { diff --git a/crates/compute_graph/src/rule.rs b/crates/compute_graph/src/rule.rs index a23691c..0dc5f68 100644 --- a/crates/compute_graph/src/rule.rs +++ b/crates/compute_graph/src/rule.rs @@ -1,5 +1,5 @@ use crate::node::{DynamicRuleOutput, NodeValue}; -use crate::NodeId; +use crate::{InvalidationSignal, NodeId}; pub use compute_graph_macros::InputVisitable; use std::cell::{Cell, Ref, RefCell}; use std::collections::{HashMap, HashSet}; @@ -116,6 +116,11 @@ pub trait DynamicRuleContext { fn add_rule(&mut self, rule: R) -> Input where R: Rule; + + /// Adds an externally-invalidatable node whose value is produced using the given rule. + fn add_invalidatable_rule(&mut self, rule: R) -> (Input, InvalidationSignal) + where + R: Rule; } /// Helper type for working with [`DynamicRule`]s. @@ -143,13 +148,13 @@ impl DynamicNodeFactory { /// /// This method must be called for every node that is part of the output. The `build` function /// will only be called for nodes that have not previously been built. - pub fn add_rule(&mut self, ctx: &mut impl DynamicRuleContext, id: ID, build: F) + pub fn add_rule(&mut self, ctx: &mut C, id: ID, build: F) where - F: FnOnce() -> R, - R: Rule, + C: DynamicRuleContext, + F: FnOnce(&mut C) -> Input, { if !self.existing_nodes.contains_key(&id) { - let input = ctx.add_rule(build()); + let input = build(ctx); self.existing_nodes.insert(id.clone(), input); } self.ids_added_this_evaluation.insert(id); @@ -158,13 +163,13 @@ impl DynamicNodeFactory { /// Registers a node that is part of the output. /// /// See [`DynamicNodeFactory::add_rule`]. - pub fn add_async_rule(&mut self, ctx: &mut impl AsyncDynamicRuleContext, id: ID, build: F) + pub fn add_async_rule(&mut self, ctx: &mut C, id: ID, build: F) where - F: FnOnce() -> R, - R: AsyncRule, + C: AsyncDynamicRuleContext, + F: FnOnce(&mut C) -> Input, { if !self.existing_nodes.contains_key(&id) { - let input = ctx.add_async_rule(build()); + let input = build(ctx); self.existing_nodes.insert(id.clone(), input); } self.ids_added_this_evaluation.insert(id); diff --git a/src/generator/posts.rs b/src/generator/posts.rs index 399dcc2..7be8072 100644 --- a/src/generator/posts.rs +++ b/src/generator/posts.rs @@ -94,8 +94,9 @@ impl DynamicRule for MakeReadNodes { type ChildOutput = ReadPostOutput; fn evaluate(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec> { for file in self.files.value().iter() { - self.node_factory - .add_rule(ctx, file.clone(), || ReadPost { path: file.clone() }); + self.node_factory.add_rule(ctx, file.clone(), |ctx| { + ctx.add_rule(ReadPost { path: file.clone() }) + }); } self.node_factory.all_nodes(ctx) } @@ -154,8 +155,8 @@ impl DynamicRule for MakeExtractMetadatas { for post_input in self.posts.value().inputs.iter() { let post_ = post_input.value(); let post = post_.as_ref().unwrap(); - self.node_factory.add_rule(ctx, post.path.clone(), || { - ExtractMetadata(post_input.clone()) + self.node_factory.add_rule(ctx, post.path.clone(), |ctx| { + ctx.add_rule(ExtractMetadata(post_input.clone())) }); } self.node_factory.all_nodes(ctx) @@ -200,8 +201,9 @@ impl DynamicRule for MakeWritePosts { fn evaluate(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec> { for post_input in self.posts.value().inputs.iter() { if let Some(post) = post_input.value().as_ref() { - self.node_factory - .add_rule(ctx, post.path.clone(), || WritePost(post_input.clone())); + self.node_factory.add_rule(ctx, post.path.clone(), |ctx| { + ctx.add_rule(WritePost(post_input.clone())) + }); } } self.node_factory.all_nodes(ctx) diff --git a/src/generator/tags.rs b/src/generator/tags.rs index 439363f..ccb6da2 100644 --- a/src/generator/tags.rs +++ b/src/generator/tags.rs @@ -51,11 +51,12 @@ impl DynamicRule for MakePostsByTags { } } for (slug, name) in all_tags { - self.node_factory - .add_rule(ctx, slug.clone(), || PostsByTag { + self.node_factory.add_rule(ctx, slug.clone(), |ctx| { + ctx.add_rule(PostsByTag { posts: self.posts.clone(), tag: Tag { slug, name }, - }); + }) + }); } self.node_factory.all_nodes(ctx) } @@ -133,8 +134,8 @@ impl DynamicRule for MakeWriteTagPages { for tag_input in self.tags.value().inputs.iter() { let tag_and_posts = tag_input.value(); self.node_factory - .add_rule(ctx, tag_and_posts.tag.slug.clone(), || { - WriteTag(tag_input.clone()) + .add_rule(ctx, tag_and_posts.tag.slug.clone(), |ctx| { + ctx.add_rule(WriteTag(tag_input.clone())) }); } self.node_factory.all_nodes(ctx)