Allow dynamic nodes to add invalidatable rules
This commit is contained in:
parent
f44f525c2c
commit
640c0ab620
@ -179,17 +179,9 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
||||
}
|
||||
|
||||
fn make_invalidation_signal<V>(&self, input: &Input<V>) -> InvalidationSignal {
|
||||
let node_idx = Rc::clone(&input.node_idx);
|
||||
let graph = Rc::clone(&self.node_graph);
|
||||
let graph_is_valid = Rc::clone(&self.is_valid);
|
||||
InvalidationSignal {
|
||||
do_invalidate: Rc::new(Box::new(move || {
|
||||
graph_is_valid.set(false);
|
||||
let mut graph = graph.borrow_mut();
|
||||
let node = &mut graph[node_idx.get().unwrap()];
|
||||
node.invalidate();
|
||||
})),
|
||||
}
|
||||
InvalidationSignal::new(input, graph, graph_is_valid)
|
||||
}
|
||||
|
||||
/// Adds a node to the graph whose output is additional nodes produced by the given rule.
|
||||
|
@ -336,7 +336,7 @@ impl<O: 'static> Graph<O, Synchronous> {
|
||||
let node = &mut graph[idx];
|
||||
if !node.is_valid() {
|
||||
// Update this node
|
||||
let mut ctx = NodeUpdateContext::new();
|
||||
let mut ctx = NodeUpdateContext::new(self);
|
||||
node.update(&mut ctx);
|
||||
|
||||
drop(graph);
|
||||
@ -389,7 +389,7 @@ impl<O: 'static> Graph<O, Asynchronous> {
|
||||
let node = &mut graph[idx];
|
||||
if !node.is_valid() {
|
||||
// Update this node
|
||||
let mut ctx = NodeUpdateContext::new();
|
||||
let mut ctx = NodeUpdateContext::new(self);
|
||||
node.update(&mut ctx).await;
|
||||
|
||||
drop(graph);
|
||||
@ -444,6 +444,22 @@ pub struct InvalidationSignal {
|
||||
}
|
||||
|
||||
impl InvalidationSignal {
|
||||
pub(crate) fn new<V, S: Synchronicity>(
|
||||
input: &Input<V>,
|
||||
graph: Rc<RefCell<NodeGraph<S>>>,
|
||||
graph_is_valid: Rc<Cell<bool>>,
|
||||
) -> Self {
|
||||
let node_idx = Rc::clone(&input.node_idx);
|
||||
InvalidationSignal {
|
||||
do_invalidate: Rc::new(Box::new(move || {
|
||||
graph_is_valid.set(false);
|
||||
let mut graph = graph.borrow_mut();
|
||||
let node = &mut graph[node_idx.get().unwrap()];
|
||||
node.invalidate();
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
/// Tell the graph that the node corresponding to this signal is now invalid.
|
||||
///
|
||||
/// Note: Calling this method does not trigger a graph evaluation, it merely marks the corresponding
|
||||
@ -724,10 +740,6 @@ mod tests {
|
||||
assert_eq!(*graph.evaluate(), NonCloneable);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn only_update_downstream_nodes_if_value_changes() {
|
||||
let mut builder = GraphBuilder::new();
|
||||
let (a, invalidate) = builder.add_invalidatable_rule(ConstantRule::new(0));
|
||||
struct IncAdd(Input<i32>, i32);
|
||||
impl InputVisitable for IncAdd {
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
@ -741,6 +753,11 @@ mod tests {
|
||||
*self.0.value() + self.1
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn only_update_downstream_nodes_if_value_changes() {
|
||||
let mut builder = GraphBuilder::new();
|
||||
let (a, invalidate) = builder.add_invalidatable_rule(ConstantRule::new(0));
|
||||
builder.set_output(IncAdd(a, 0));
|
||||
let mut graph = builder.build().unwrap();
|
||||
assert_eq!(*graph.evaluate(), 1);
|
||||
@ -811,6 +828,19 @@ mod tests {
|
||||
)
|
||||
}
|
||||
|
||||
struct DynamicSum(DynamicInput<i32>);
|
||||
impl InputVisitable for DynamicSum {
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
visitor.visit_dynamic(&self.0);
|
||||
}
|
||||
}
|
||||
impl Rule for DynamicSum {
|
||||
type Output = i32;
|
||||
fn evaluate(&mut self) -> Self::Output {
|
||||
self.0.value().inputs.iter().map(|i| *i.value()).sum()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dynamic_rule() {
|
||||
let mut builder = GraphBuilder::new();
|
||||
@ -832,7 +862,8 @@ mod tests {
|
||||
) -> Vec<Input<Self::ChildOutput>> {
|
||||
let count = *self.count.value();
|
||||
for i in 1..=count {
|
||||
self.node_factory.add_rule(ctx, i, || ConstantRule::new(i));
|
||||
self.node_factory
|
||||
.add_rule(ctx, i, |ctx| ctx.add_rule(ConstantRule::new(i)));
|
||||
}
|
||||
self.node_factory.all_nodes(ctx)
|
||||
}
|
||||
@ -841,19 +872,7 @@ mod tests {
|
||||
count,
|
||||
node_factory: DynamicNodeFactory::new(),
|
||||
});
|
||||
struct Sum(DynamicInput<i32>);
|
||||
impl InputVisitable for Sum {
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
visitor.visit_dynamic(&self.0);
|
||||
}
|
||||
}
|
||||
impl Rule for Sum {
|
||||
type Output = i32;
|
||||
fn evaluate(&mut self) -> Self::Output {
|
||||
self.0.value().inputs.iter().map(|i| *i.value()).sum()
|
||||
}
|
||||
}
|
||||
builder.set_output(Sum(all_inputs));
|
||||
builder.set_output(DynamicSum(all_inputs));
|
||||
let mut graph = builder.build().unwrap();
|
||||
assert_eq!(*graph.evaluate(), 1);
|
||||
set_count.set_value(2);
|
||||
@ -891,19 +910,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
let all_inputs = builder.add_async_dynamic_rule(CountUpTo(count, vec![]));
|
||||
struct Sum(DynamicInput<i32>);
|
||||
impl InputVisitable for Sum {
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
visitor.visit_dynamic(&self.0);
|
||||
}
|
||||
}
|
||||
impl Rule for Sum {
|
||||
type Output = i32;
|
||||
fn evaluate(&mut self) -> Self::Output {
|
||||
self.0.value().inputs.iter().map(|i| *i.value()).sum()
|
||||
}
|
||||
}
|
||||
builder.set_output(Sum(all_inputs));
|
||||
builder.set_output(DynamicSum(all_inputs));
|
||||
let mut graph = builder.build().unwrap();
|
||||
assert_eq!(*graph.evaluate_async().await, 1);
|
||||
set_count.set_value(2);
|
||||
@ -912,4 +919,57 @@ mod tests {
|
||||
assert_eq!(*graph.evaluate_async().await, 10);
|
||||
println!("{}", graph.as_dot_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dynamic_invalidatable_rule() {
|
||||
let mut builder = GraphBuilder::new();
|
||||
let (count, set_count) = builder.add_invalidatable_value(1);
|
||||
struct CountUpTo {
|
||||
count: Input<i32>,
|
||||
signals: Rc<RefCell<Vec<InvalidationSignal>>>,
|
||||
node_factory: DynamicNodeFactory<i32, i32>,
|
||||
}
|
||||
impl InputVisitable for CountUpTo {
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
visitor.visit(&self.count);
|
||||
}
|
||||
}
|
||||
impl DynamicRule for CountUpTo {
|
||||
type ChildOutput = i32;
|
||||
fn evaluate(
|
||||
&mut self,
|
||||
ctx: &mut impl rule::DynamicRuleContext,
|
||||
) -> Vec<Input<Self::ChildOutput>> {
|
||||
let count = *self.count.value();
|
||||
for i in 1..=count {
|
||||
self.node_factory.add_rule(ctx, i, |ctx| {
|
||||
let constant = ctx.add_rule(ConstantRule::new(i));
|
||||
let (input, signal) = ctx.add_invalidatable_rule(IncAdd(constant, 0));
|
||||
self.signals.borrow_mut().push(signal);
|
||||
input
|
||||
});
|
||||
}
|
||||
self.node_factory.all_nodes(ctx)
|
||||
}
|
||||
}
|
||||
let signals = Rc::new(RefCell::new(vec![]));
|
||||
let all_inputs = builder.add_dynamic_rule(CountUpTo {
|
||||
count,
|
||||
signals: Rc::clone(&signals),
|
||||
node_factory: DynamicNodeFactory::new(),
|
||||
});
|
||||
builder.set_output(DynamicSum(all_inputs));
|
||||
let mut graph = builder.build().unwrap();
|
||||
assert_eq!(*graph.evaluate(), 2);
|
||||
for signal in signals.borrow().iter() {
|
||||
signal.invalidate();
|
||||
}
|
||||
assert_eq!(*graph.evaluate(), 3);
|
||||
set_count.set_value(2);
|
||||
assert_eq!(*graph.evaluate(), 6); // new const node has value 2, IncAdd initially adds 1
|
||||
for signal in signals.borrow().iter() {
|
||||
signal.invalidate();
|
||||
}
|
||||
assert_eq!(*graph.evaluate(), 8);
|
||||
}
|
||||
}
|
||||
|
@ -3,7 +3,7 @@ use crate::rule::{
|
||||
DynamicRuleContext, InputVisitable, Rule,
|
||||
};
|
||||
use crate::synchronicity::{Asynchronous, Synchronicity};
|
||||
use crate::{Input, InputVisitor, NodeId, Synchronous};
|
||||
use crate::{Graph, Input, InputVisitor, InvalidationSignal, NodeGraph, NodeId, Synchronous};
|
||||
use quote::ToTokens;
|
||||
use std::any::Any;
|
||||
use std::cell::{Cell, RefCell};
|
||||
@ -25,14 +25,18 @@ pub(crate) struct ErasedNode<Synch: Synchronicity> {
|
||||
}
|
||||
|
||||
pub(crate) struct NodeUpdateContext<Synch: Synchronicity> {
|
||||
pub(crate) graph: Rc<RefCell<NodeGraph<Synch>>>,
|
||||
pub(crate) graph_is_valid: Rc<Cell<bool>>,
|
||||
pub(crate) invalidate_dependent_nodes: bool,
|
||||
pub(crate) removed_nodes: Vec<NodeId>,
|
||||
pub(crate) added_nodes: Vec<(ErasedNode<Synch>, Rc<Cell<Option<NodeId>>>)>,
|
||||
}
|
||||
|
||||
impl<S: Synchronicity> NodeUpdateContext<S> {
|
||||
pub(crate) fn new() -> Self {
|
||||
pub(crate) fn new<O>(graph: &Graph<O, S>) -> Self {
|
||||
Self {
|
||||
graph: Rc::clone(&graph.node_graph),
|
||||
graph_is_valid: Rc::clone(&graph.is_valid),
|
||||
invalidate_dependent_nodes: false,
|
||||
removed_nodes: vec![],
|
||||
added_nodes: vec![],
|
||||
@ -562,6 +566,19 @@ impl<'a, S: Synchronicity> DynamicRuleContext for DynamicRuleUpdateContext<'a, S
|
||||
{
|
||||
self.add_node(RuleNode::new(rule))
|
||||
}
|
||||
|
||||
fn add_invalidatable_rule<R>(&mut self, rule: R) -> (Input<R::Output>, InvalidationSignal)
|
||||
where
|
||||
R: Rule,
|
||||
{
|
||||
let input = self.add_rule(rule);
|
||||
let signal = InvalidationSignal::new(
|
||||
&input,
|
||||
Rc::clone(&self.0.graph),
|
||||
Rc::clone(&self.0.graph_is_valid),
|
||||
);
|
||||
(input, signal)
|
||||
}
|
||||
}
|
||||
|
||||
struct DynamicRuleLabel<'a, R: DynamicRule>(&'a R);
|
||||
@ -658,6 +675,13 @@ impl<'a> DynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> {
|
||||
{
|
||||
DynamicRuleUpdateContext(self.0).add_rule(rule)
|
||||
}
|
||||
|
||||
fn add_invalidatable_rule<R>(&mut self, rule: R) -> (Input<R::Output>, InvalidationSignal)
|
||||
where
|
||||
R: Rule,
|
||||
{
|
||||
DynamicRuleUpdateContext(self.0).add_invalidatable_rule(rule)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> AsyncDynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> {
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::node::{DynamicRuleOutput, NodeValue};
|
||||
use crate::NodeId;
|
||||
use crate::{InvalidationSignal, NodeId};
|
||||
pub use compute_graph_macros::InputVisitable;
|
||||
use std::cell::{Cell, Ref, RefCell};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
@ -116,6 +116,11 @@ pub trait DynamicRuleContext {
|
||||
fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
|
||||
where
|
||||
R: Rule;
|
||||
|
||||
/// Adds an externally-invalidatable node whose value is produced using the given rule.
|
||||
fn add_invalidatable_rule<R>(&mut self, rule: R) -> (Input<R::Output>, InvalidationSignal)
|
||||
where
|
||||
R: Rule;
|
||||
}
|
||||
|
||||
/// Helper type for working with [`DynamicRule`]s.
|
||||
@ -143,13 +148,13 @@ impl<ID: Hash + Eq + Clone, ChildOutput> DynamicNodeFactory<ID, ChildOutput> {
|
||||
///
|
||||
/// This method must be called for every node that is part of the output. The `build` function
|
||||
/// will only be called for nodes that have not previously been built.
|
||||
pub fn add_rule<F, R>(&mut self, ctx: &mut impl DynamicRuleContext, id: ID, build: F)
|
||||
pub fn add_rule<C, F>(&mut self, ctx: &mut C, id: ID, build: F)
|
||||
where
|
||||
F: FnOnce() -> R,
|
||||
R: Rule<Output = ChildOutput>,
|
||||
C: DynamicRuleContext,
|
||||
F: FnOnce(&mut C) -> Input<ChildOutput>,
|
||||
{
|
||||
if !self.existing_nodes.contains_key(&id) {
|
||||
let input = ctx.add_rule(build());
|
||||
let input = build(ctx);
|
||||
self.existing_nodes.insert(id.clone(), input);
|
||||
}
|
||||
self.ids_added_this_evaluation.insert(id);
|
||||
@ -158,13 +163,13 @@ impl<ID: Hash + Eq + Clone, ChildOutput> DynamicNodeFactory<ID, ChildOutput> {
|
||||
/// Registers a node that is part of the output.
|
||||
///
|
||||
/// See [`DynamicNodeFactory::add_rule`].
|
||||
pub fn add_async_rule<F, R>(&mut self, ctx: &mut impl AsyncDynamicRuleContext, id: ID, build: F)
|
||||
pub fn add_async_rule<C, F>(&mut self, ctx: &mut C, id: ID, build: F)
|
||||
where
|
||||
F: FnOnce() -> R,
|
||||
R: AsyncRule<Output = ChildOutput>,
|
||||
C: AsyncDynamicRuleContext,
|
||||
F: FnOnce(&mut C) -> Input<ChildOutput>,
|
||||
{
|
||||
if !self.existing_nodes.contains_key(&id) {
|
||||
let input = ctx.add_async_rule(build());
|
||||
let input = build(ctx);
|
||||
self.existing_nodes.insert(id.clone(), input);
|
||||
}
|
||||
self.ids_added_this_evaluation.insert(id);
|
||||
|
@ -94,8 +94,9 @@ impl DynamicRule for MakeReadNodes {
|
||||
type ChildOutput = ReadPostOutput;
|
||||
fn evaluate(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec<Input<Self::ChildOutput>> {
|
||||
for file in self.files.value().iter() {
|
||||
self.node_factory
|
||||
.add_rule(ctx, file.clone(), || ReadPost { path: file.clone() });
|
||||
self.node_factory.add_rule(ctx, file.clone(), |ctx| {
|
||||
ctx.add_rule(ReadPost { path: file.clone() })
|
||||
});
|
||||
}
|
||||
self.node_factory.all_nodes(ctx)
|
||||
}
|
||||
@ -154,8 +155,8 @@ impl DynamicRule for MakeExtractMetadatas {
|
||||
for post_input in self.posts.value().inputs.iter() {
|
||||
let post_ = post_input.value();
|
||||
let post = post_.as_ref().unwrap();
|
||||
self.node_factory.add_rule(ctx, post.path.clone(), || {
|
||||
ExtractMetadata(post_input.clone())
|
||||
self.node_factory.add_rule(ctx, post.path.clone(), |ctx| {
|
||||
ctx.add_rule(ExtractMetadata(post_input.clone()))
|
||||
});
|
||||
}
|
||||
self.node_factory.all_nodes(ctx)
|
||||
@ -200,8 +201,9 @@ impl DynamicRule for MakeWritePosts {
|
||||
fn evaluate(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec<Input<Self::ChildOutput>> {
|
||||
for post_input in self.posts.value().inputs.iter() {
|
||||
if let Some(post) = post_input.value().as_ref() {
|
||||
self.node_factory
|
||||
.add_rule(ctx, post.path.clone(), || WritePost(post_input.clone()));
|
||||
self.node_factory.add_rule(ctx, post.path.clone(), |ctx| {
|
||||
ctx.add_rule(WritePost(post_input.clone()))
|
||||
});
|
||||
}
|
||||
}
|
||||
self.node_factory.all_nodes(ctx)
|
||||
|
@ -51,10 +51,11 @@ impl DynamicRule for MakePostsByTags {
|
||||
}
|
||||
}
|
||||
for (slug, name) in all_tags {
|
||||
self.node_factory
|
||||
.add_rule(ctx, slug.clone(), || PostsByTag {
|
||||
self.node_factory.add_rule(ctx, slug.clone(), |ctx| {
|
||||
ctx.add_rule(PostsByTag {
|
||||
posts: self.posts.clone(),
|
||||
tag: Tag { slug, name },
|
||||
})
|
||||
});
|
||||
}
|
||||
self.node_factory.all_nodes(ctx)
|
||||
@ -133,8 +134,8 @@ impl DynamicRule for MakeWriteTagPages {
|
||||
for tag_input in self.tags.value().inputs.iter() {
|
||||
let tag_and_posts = tag_input.value();
|
||||
self.node_factory
|
||||
.add_rule(ctx, tag_and_posts.tag.slug.clone(), || {
|
||||
WriteTag(tag_input.clone())
|
||||
.add_rule(ctx, tag_and_posts.tag.slug.clone(), |ctx| {
|
||||
ctx.add_rule(WriteTag(tag_input.clone()))
|
||||
});
|
||||
}
|
||||
self.node_factory.all_nodes(ctx)
|
||||
|
Loading…
x
Reference in New Issue
Block a user