Allow dynamic nodes to add invalidatable rules

This commit is contained in:
Shadowfacts 2024-12-31 18:49:00 -05:00
parent f44f525c2c
commit 640c0ab620
6 changed files with 157 additions and 73 deletions

View File

@ -179,17 +179,9 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
}
fn make_invalidation_signal<V>(&self, input: &Input<V>) -> InvalidationSignal {
let node_idx = Rc::clone(&input.node_idx);
let graph = Rc::clone(&self.node_graph);
let graph_is_valid = Rc::clone(&self.is_valid);
InvalidationSignal {
do_invalidate: Rc::new(Box::new(move || {
graph_is_valid.set(false);
let mut graph = graph.borrow_mut();
let node = &mut graph[node_idx.get().unwrap()];
node.invalidate();
})),
}
InvalidationSignal::new(input, graph, graph_is_valid)
}
/// Adds a node to the graph whose output is additional nodes produced by the given rule.

View File

@ -336,7 +336,7 @@ impl<O: 'static> Graph<O, Synchronous> {
let node = &mut graph[idx];
if !node.is_valid() {
// Update this node
let mut ctx = NodeUpdateContext::new();
let mut ctx = NodeUpdateContext::new(self);
node.update(&mut ctx);
drop(graph);
@ -389,7 +389,7 @@ impl<O: 'static> Graph<O, Asynchronous> {
let node = &mut graph[idx];
if !node.is_valid() {
// Update this node
let mut ctx = NodeUpdateContext::new();
let mut ctx = NodeUpdateContext::new(self);
node.update(&mut ctx).await;
drop(graph);
@ -444,6 +444,22 @@ pub struct InvalidationSignal {
}
impl InvalidationSignal {
pub(crate) fn new<V, S: Synchronicity>(
input: &Input<V>,
graph: Rc<RefCell<NodeGraph<S>>>,
graph_is_valid: Rc<Cell<bool>>,
) -> Self {
let node_idx = Rc::clone(&input.node_idx);
InvalidationSignal {
do_invalidate: Rc::new(Box::new(move || {
graph_is_valid.set(false);
let mut graph = graph.borrow_mut();
let node = &mut graph[node_idx.get().unwrap()];
node.invalidate();
})),
}
}
/// Tell the graph that the node corresponding to this signal is now invalid.
///
/// Note: Calling this method does not trigger a graph evaluation, it merely marks the corresponding
@ -724,23 +740,24 @@ mod tests {
assert_eq!(*graph.evaluate(), NonCloneable);
}
struct IncAdd(Input<i32>, i32);
impl InputVisitable for IncAdd {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
}
}
impl Rule for IncAdd {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
self.1 += 1;
*self.0.value() + self.1
}
}
#[test]
fn only_update_downstream_nodes_if_value_changes() {
let mut builder = GraphBuilder::new();
let (a, invalidate) = builder.add_invalidatable_rule(ConstantRule::new(0));
struct IncAdd(Input<i32>, i32);
impl InputVisitable for IncAdd {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
}
}
impl Rule for IncAdd {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
self.1 += 1;
*self.0.value() + self.1
}
}
builder.set_output(IncAdd(a, 0));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 1);
@ -811,6 +828,19 @@ mod tests {
)
}
struct DynamicSum(DynamicInput<i32>);
impl InputVisitable for DynamicSum {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit_dynamic(&self.0);
}
}
impl Rule for DynamicSum {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
self.0.value().inputs.iter().map(|i| *i.value()).sum()
}
}
#[test]
fn dynamic_rule() {
let mut builder = GraphBuilder::new();
@ -832,7 +862,8 @@ mod tests {
) -> Vec<Input<Self::ChildOutput>> {
let count = *self.count.value();
for i in 1..=count {
self.node_factory.add_rule(ctx, i, || ConstantRule::new(i));
self.node_factory
.add_rule(ctx, i, |ctx| ctx.add_rule(ConstantRule::new(i)));
}
self.node_factory.all_nodes(ctx)
}
@ -841,19 +872,7 @@ mod tests {
count,
node_factory: DynamicNodeFactory::new(),
});
struct Sum(DynamicInput<i32>);
impl InputVisitable for Sum {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit_dynamic(&self.0);
}
}
impl Rule for Sum {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
self.0.value().inputs.iter().map(|i| *i.value()).sum()
}
}
builder.set_output(Sum(all_inputs));
builder.set_output(DynamicSum(all_inputs));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 1);
set_count.set_value(2);
@ -891,19 +910,7 @@ mod tests {
}
}
let all_inputs = builder.add_async_dynamic_rule(CountUpTo(count, vec![]));
struct Sum(DynamicInput<i32>);
impl InputVisitable for Sum {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit_dynamic(&self.0);
}
}
impl Rule for Sum {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
self.0.value().inputs.iter().map(|i| *i.value()).sum()
}
}
builder.set_output(Sum(all_inputs));
builder.set_output(DynamicSum(all_inputs));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate_async().await, 1);
set_count.set_value(2);
@ -912,4 +919,57 @@ mod tests {
assert_eq!(*graph.evaluate_async().await, 10);
println!("{}", graph.as_dot_string());
}
#[test]
fn dynamic_invalidatable_rule() {
let mut builder = GraphBuilder::new();
let (count, set_count) = builder.add_invalidatable_value(1);
struct CountUpTo {
count: Input<i32>,
signals: Rc<RefCell<Vec<InvalidationSignal>>>,
node_factory: DynamicNodeFactory<i32, i32>,
}
impl InputVisitable for CountUpTo {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.count);
}
}
impl DynamicRule for CountUpTo {
type ChildOutput = i32;
fn evaluate(
&mut self,
ctx: &mut impl rule::DynamicRuleContext,
) -> Vec<Input<Self::ChildOutput>> {
let count = *self.count.value();
for i in 1..=count {
self.node_factory.add_rule(ctx, i, |ctx| {
let constant = ctx.add_rule(ConstantRule::new(i));
let (input, signal) = ctx.add_invalidatable_rule(IncAdd(constant, 0));
self.signals.borrow_mut().push(signal);
input
});
}
self.node_factory.all_nodes(ctx)
}
}
let signals = Rc::new(RefCell::new(vec![]));
let all_inputs = builder.add_dynamic_rule(CountUpTo {
count,
signals: Rc::clone(&signals),
node_factory: DynamicNodeFactory::new(),
});
builder.set_output(DynamicSum(all_inputs));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 2);
for signal in signals.borrow().iter() {
signal.invalidate();
}
assert_eq!(*graph.evaluate(), 3);
set_count.set_value(2);
assert_eq!(*graph.evaluate(), 6); // new const node has value 2, IncAdd initially adds 1
for signal in signals.borrow().iter() {
signal.invalidate();
}
assert_eq!(*graph.evaluate(), 8);
}
}

View File

@ -3,7 +3,7 @@ use crate::rule::{
DynamicRuleContext, InputVisitable, Rule,
};
use crate::synchronicity::{Asynchronous, Synchronicity};
use crate::{Input, InputVisitor, NodeId, Synchronous};
use crate::{Graph, Input, InputVisitor, InvalidationSignal, NodeGraph, NodeId, Synchronous};
use quote::ToTokens;
use std::any::Any;
use std::cell::{Cell, RefCell};
@ -25,14 +25,18 @@ pub(crate) struct ErasedNode<Synch: Synchronicity> {
}
pub(crate) struct NodeUpdateContext<Synch: Synchronicity> {
pub(crate) graph: Rc<RefCell<NodeGraph<Synch>>>,
pub(crate) graph_is_valid: Rc<Cell<bool>>,
pub(crate) invalidate_dependent_nodes: bool,
pub(crate) removed_nodes: Vec<NodeId>,
pub(crate) added_nodes: Vec<(ErasedNode<Synch>, Rc<Cell<Option<NodeId>>>)>,
}
impl<S: Synchronicity> NodeUpdateContext<S> {
pub(crate) fn new() -> Self {
pub(crate) fn new<O>(graph: &Graph<O, S>) -> Self {
Self {
graph: Rc::clone(&graph.node_graph),
graph_is_valid: Rc::clone(&graph.is_valid),
invalidate_dependent_nodes: false,
removed_nodes: vec![],
added_nodes: vec![],
@ -562,6 +566,19 @@ impl<'a, S: Synchronicity> DynamicRuleContext for DynamicRuleUpdateContext<'a, S
{
self.add_node(RuleNode::new(rule))
}
fn add_invalidatable_rule<R>(&mut self, rule: R) -> (Input<R::Output>, InvalidationSignal)
where
R: Rule,
{
let input = self.add_rule(rule);
let signal = InvalidationSignal::new(
&input,
Rc::clone(&self.0.graph),
Rc::clone(&self.0.graph_is_valid),
);
(input, signal)
}
}
struct DynamicRuleLabel<'a, R: DynamicRule>(&'a R);
@ -658,6 +675,13 @@ impl<'a> DynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> {
{
DynamicRuleUpdateContext(self.0).add_rule(rule)
}
fn add_invalidatable_rule<R>(&mut self, rule: R) -> (Input<R::Output>, InvalidationSignal)
where
R: Rule,
{
DynamicRuleUpdateContext(self.0).add_invalidatable_rule(rule)
}
}
impl<'a> AsyncDynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> {

View File

@ -1,5 +1,5 @@
use crate::node::{DynamicRuleOutput, NodeValue};
use crate::NodeId;
use crate::{InvalidationSignal, NodeId};
pub use compute_graph_macros::InputVisitable;
use std::cell::{Cell, Ref, RefCell};
use std::collections::{HashMap, HashSet};
@ -116,6 +116,11 @@ pub trait DynamicRuleContext {
fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: Rule;
/// Adds an externally-invalidatable node whose value is produced using the given rule.
fn add_invalidatable_rule<R>(&mut self, rule: R) -> (Input<R::Output>, InvalidationSignal)
where
R: Rule;
}
/// Helper type for working with [`DynamicRule`]s.
@ -143,13 +148,13 @@ impl<ID: Hash + Eq + Clone, ChildOutput> DynamicNodeFactory<ID, ChildOutput> {
///
/// This method must be called for every node that is part of the output. The `build` function
/// will only be called for nodes that have not previously been built.
pub fn add_rule<F, R>(&mut self, ctx: &mut impl DynamicRuleContext, id: ID, build: F)
pub fn add_rule<C, F>(&mut self, ctx: &mut C, id: ID, build: F)
where
F: FnOnce() -> R,
R: Rule<Output = ChildOutput>,
C: DynamicRuleContext,
F: FnOnce(&mut C) -> Input<ChildOutput>,
{
if !self.existing_nodes.contains_key(&id) {
let input = ctx.add_rule(build());
let input = build(ctx);
self.existing_nodes.insert(id.clone(), input);
}
self.ids_added_this_evaluation.insert(id);
@ -158,13 +163,13 @@ impl<ID: Hash + Eq + Clone, ChildOutput> DynamicNodeFactory<ID, ChildOutput> {
/// Registers a node that is part of the output.
///
/// See [`DynamicNodeFactory::add_rule`].
pub fn add_async_rule<F, R>(&mut self, ctx: &mut impl AsyncDynamicRuleContext, id: ID, build: F)
pub fn add_async_rule<C, F>(&mut self, ctx: &mut C, id: ID, build: F)
where
F: FnOnce() -> R,
R: AsyncRule<Output = ChildOutput>,
C: AsyncDynamicRuleContext,
F: FnOnce(&mut C) -> Input<ChildOutput>,
{
if !self.existing_nodes.contains_key(&id) {
let input = ctx.add_async_rule(build());
let input = build(ctx);
self.existing_nodes.insert(id.clone(), input);
}
self.ids_added_this_evaluation.insert(id);

View File

@ -94,8 +94,9 @@ impl DynamicRule for MakeReadNodes {
type ChildOutput = ReadPostOutput;
fn evaluate(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec<Input<Self::ChildOutput>> {
for file in self.files.value().iter() {
self.node_factory
.add_rule(ctx, file.clone(), || ReadPost { path: file.clone() });
self.node_factory.add_rule(ctx, file.clone(), |ctx| {
ctx.add_rule(ReadPost { path: file.clone() })
});
}
self.node_factory.all_nodes(ctx)
}
@ -154,8 +155,8 @@ impl DynamicRule for MakeExtractMetadatas {
for post_input in self.posts.value().inputs.iter() {
let post_ = post_input.value();
let post = post_.as_ref().unwrap();
self.node_factory.add_rule(ctx, post.path.clone(), || {
ExtractMetadata(post_input.clone())
self.node_factory.add_rule(ctx, post.path.clone(), |ctx| {
ctx.add_rule(ExtractMetadata(post_input.clone()))
});
}
self.node_factory.all_nodes(ctx)
@ -200,8 +201,9 @@ impl DynamicRule for MakeWritePosts {
fn evaluate(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec<Input<Self::ChildOutput>> {
for post_input in self.posts.value().inputs.iter() {
if let Some(post) = post_input.value().as_ref() {
self.node_factory
.add_rule(ctx, post.path.clone(), || WritePost(post_input.clone()));
self.node_factory.add_rule(ctx, post.path.clone(), |ctx| {
ctx.add_rule(WritePost(post_input.clone()))
});
}
}
self.node_factory.all_nodes(ctx)

View File

@ -51,11 +51,12 @@ impl DynamicRule for MakePostsByTags {
}
}
for (slug, name) in all_tags {
self.node_factory
.add_rule(ctx, slug.clone(), || PostsByTag {
self.node_factory.add_rule(ctx, slug.clone(), |ctx| {
ctx.add_rule(PostsByTag {
posts: self.posts.clone(),
tag: Tag { slug, name },
});
})
});
}
self.node_factory.all_nodes(ctx)
}
@ -133,8 +134,8 @@ impl DynamicRule for MakeWriteTagPages {
for tag_input in self.tags.value().inputs.iter() {
let tag_and_posts = tag_input.value();
self.node_factory
.add_rule(ctx, tag_and_posts.tag.slug.clone(), || {
WriteTag(tag_input.clone())
.add_rule(ctx, tag_and_posts.tag.slug.clone(), |ctx| {
ctx.add_rule(WriteTag(tag_input.clone()))
});
}
self.node_factory.all_nodes(ctx)