From d92ebf11b2c2b5e68d438c33409cd4806b2a1cc4 Mon Sep 17 00:00:00 2001 From: Shadowfacts Date: Sun, 29 Dec 2024 13:37:54 -0500 Subject: [PATCH] Dynamic rules --- crates/compute_graph/src/builder.rs | 74 ++++- crates/compute_graph/src/lib.rs | 247 +++++++++++++-- crates/compute_graph/src/node.rs | 356 +++++++++++++++++++--- crates/compute_graph/src/rule.rs | 101 +++++- crates/compute_graph/src/synchronicity.rs | 16 +- crates/compute_graph_macros/src/lib.rs | 69 +++-- crates/derive_test/src/lib.rs | 43 ++- 7 files changed, 803 insertions(+), 103 deletions(-) diff --git a/crates/compute_graph/src/builder.rs b/crates/compute_graph/src/builder.rs index 8cd77bb..5fbd43e 100644 --- a/crates/compute_graph/src/builder.rs +++ b/crates/compute_graph/src/builder.rs @@ -1,8 +1,8 @@ use crate::node::{ - AsyncConstNode, AsyncRuleNode, ConstNode, ErasedNode, InvalidatableConstNode, Node, NodeValue, - RuleNode, + AsyncConstNode, AsyncDynamicRuleNode, AsyncRuleNode, ConstNode, DynamicRuleNode, ErasedNode, + InvalidatableConstNode, Node, NodeValue, RuleNode, }; -use crate::rule::{AsyncRule, Input, Rule}; +use crate::rule::{AsyncDynamicRule, AsyncRule, DynamicInput, DynamicRule, Input, Rule}; use crate::synchronicity::{Asynchronous, Synchronicity, Synchronous}; use crate::util; use crate::{Graph, InvalidationSignal, NodeGraph, NodeId, ValueInvalidationSignal}; @@ -73,7 +73,7 @@ impl GraphBuilder { let erased = ErasedNode::new(node); let idx = self.node_graph.borrow_mut().add_node(erased); Input { - node_idx: idx, + node_idx: Rc::new(Cell::new(Some(idx))), value, } } @@ -174,19 +174,42 @@ impl GraphBuilder { } fn make_invalidation_signal(&self, input: &Input) -> InvalidationSignal { - let node_idx = input.node_idx; + let node_idx = Rc::clone(&input.node_idx); let graph = Rc::clone(&self.node_graph); let graph_is_valid = Rc::clone(&self.is_valid); InvalidationSignal { do_invalidate: Rc::new(Box::new(move || { graph_is_valid.set(false); let mut graph = graph.borrow_mut(); - let node = &mut graph[node_idx]; + let node = &mut graph[node_idx.get().unwrap()]; node.invalidate(); })), } } + /// Adds a node to the graph whose output is additional nodes produced by the given rule. + pub fn add_dynamic_rule(&mut self, rule: R) -> DynamicInput + where + R: DynamicRule, + { + let input = self.add_node(DynamicRuleNode::::new(rule)); + DynamicInput { input } + } + + /// Adds an externally-invalidatable node to the graph whose output is additional + /// nodes produced by the given rule. + pub fn add_invalidatable_dynamic_rule( + &mut self, + rule: R, + ) -> (DynamicInput, InvalidationSignal) + where + R: DynamicRule, + { + let input = self.add_dynamic_rule(rule); + let signal = self.make_invalidation_signal(&input.input); + (input, signal) + } + /// Creates a graph from this builder, consuming the builder. /// /// To successfully build a graph, there must be an output node set (using either @@ -217,7 +240,7 @@ impl GraphBuilder { graph.add_edge(source, dest, ()); } - util::remove_nodes_not_connected_to(&mut *graph, output.node_idx); + util::remove_nodes_not_connected_to(&mut *graph, output.node_idx.get().unwrap()); drop(graph); @@ -319,6 +342,29 @@ impl GraphBuilder { let signal = self.make_invalidation_signal(&input); (input, signal) } + + /// Adds a node to the graph whose output is additional nodes produced asynchronously by the given rule. + pub fn add_async_dynamic_rule(&mut self, rule: R) -> DynamicInput + where + R: AsyncDynamicRule, + { + let input = self.add_node(AsyncDynamicRuleNode::::new(rule)); + DynamicInput { input } + } + + /// Adds an externally-invalidatable node to the graph whose output is additional nodes produced + /// asynchronously by the given rule. + pub fn add_invalidatable_async_dynamic_rule( + &mut self, + rule: R, + ) -> (DynamicInput, InvalidationSignal) + where + R: AsyncDynamicRule, + { + let input = self.add_async_dynamic_rule(rule); + let signal = self.make_invalidation_signal(&input.input); + (input, signal) + } } /// A reason why a [`GraphBuilder`] can fail to build a graph. @@ -383,8 +429,18 @@ mod tests { builder.set_output(Double::new(b.clone())); match builder.build() { Err(super::BuildGraphError::Cycle(cycle)) => { - let a_start = cycle == vec![a.node_idx, b.node_idx, a.node_idx]; - let b_start = cycle == vec![b.node_idx, a.node_idx, b.node_idx]; + let a_start = cycle + == vec![ + a.node_idx.get().unwrap(), + b.node_idx.get().unwrap(), + a.node_idx.get().unwrap(), + ]; + let b_start = cycle + == vec![ + b.node_idx.get().unwrap(), + a.node_idx.get().unwrap(), + b.node_idx.get().unwrap(), + ]; // either is a permisisble way of describing the cycle assert!(a_start || b_start); } diff --git a/crates/compute_graph/src/lib.rs b/crates/compute_graph/src/lib.rs index 3f744c4..a0a7f0a 100644 --- a/crates/compute_graph/src/lib.rs +++ b/crates/compute_graph/src/lib.rs @@ -49,10 +49,10 @@ pub mod synchronicity; mod util; use builder::{BuildGraphError, GraphBuilder}; -use node::{ErasedNode, NodeValue}; +use node::{ErasedNode, NodeUpdateContext, NodeValue}; use petgraph::visit::{IntoEdgeReferences, IntoNodeReferences, NodeIndexable, NodeRef}; use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef}; -use rule::{AsyncRule, Input, InputVisitor, Rule}; +use rule::{Input, InputVisitor}; use std::cell::{Cell, RefCell}; use std::collections::HashMap; use std::collections::VecDeque; @@ -127,7 +127,15 @@ impl Graph { /// Because building a graph can fail and this method mutates the underlying graph, it takes /// ownership of the current graph to prevent the graph being left in an invalid state. /// It returns either the new, modified graph or an error. - pub fn modify(mut self, mut f: F) -> Result + pub fn modify(mut self, f: F) -> Result + where + F: FnMut(&mut GraphBuilder) -> (), + { + self._modify(f)?; + Ok(self) + } + + fn _modify(&mut self, mut f: F) -> Result<(), BuildGraphError> where F: FnMut(&mut GraphBuilder) -> (), { @@ -142,12 +150,12 @@ impl Graph { } drop(graph); - let old_output = self.output.node_idx; + let old_output = self.output.node_idx.get(); // Modify - let mut builder = self.into_builder(); + let mut builder = self.to_builder(); f(&mut builder); - self = builder.build()?; + *self = builder.build()?; // Any new inboud edges invalidate their target nodes. let mut graph = self.node_graph.borrow_mut(); @@ -164,7 +172,7 @@ impl Graph { } // Edge case: if the only node in the graph is the output node, and it's replaced in the modify block, // there are no edges but we still need to invalidate. - if !to_invalidate.is_empty() || self.output.node_idx != old_output { + if !to_invalidate.is_empty() || self.output.node_idx.get() != old_output { self.is_valid.set(false); for idx in to_invalidate { let node = &mut graph[idx]; @@ -173,13 +181,17 @@ impl Graph { } drop(graph); - Ok(self) + Ok(()) } /// Convert this graph back into a builder for further modifications. /// /// Returns a builder with the same output and synchronicity types. pub fn into_builder(self) -> GraphBuilder { + self.to_builder() + } + + fn to_builder(&self) -> GraphBuilder { // Clear the edges before modifying so that rebuilding results in a graph with up-to-date edges. let mut graph = self.node_graph.borrow_mut(); graph.clear_edges(); @@ -232,7 +244,7 @@ impl Graph { for node in self.0.node_references() { let id = self.0.to_index(node.id()); let label = Escaped(node.weight()); - writeln!(f, "\t{id} [label =\"{label:?} (id={id})\"]")?; + writeln!(f, "\t{id} [label=\"{label:?} (id={id})\"]")?; } for edge in self.0.edge_references() { let source = self.0.to_index(edge.source()); @@ -250,13 +262,51 @@ impl Graph { impl Graph { fn update_invalid_nodes(&mut self) { let mut graph = self.node_graph.borrow_mut(); - for &idx in self.sorted_nodes.iter() { + let mut i = 0; + while i < self.sorted_nodes.len() { + let idx = self.sorted_nodes[i]; let node = &mut graph[idx]; if !node.is_valid() { // Update this node - let value_changed = node.update(); + let mut ctx = NodeUpdateContext::new(); + node.update(&mut ctx); - if value_changed { + let mut nodes_changed = false; + for idx_to_remove in ctx.removed_nodes { + assert!( + idx_to_remove != idx, + "cannot remove node curently being evaluated" + ); + let (index_to_remove_in_sorted, _) = self + .sorted_nodes + .iter() + .enumerate() + .find(|(_, idx)| **idx == idx_to_remove) + .expect("removed node must have been already added"); + assert!( + index_to_remove_in_sorted > i, + "cannot remove already evaluated node" + ); + graph.remove_node(idx_to_remove); + self.sorted_nodes.remove(index_to_remove_in_sorted); + nodes_changed = true; + } + + for (added_node, id_cell) in ctx.added_nodes { + let id = graph.add_node(added_node); + id_cell.set(Some(id)); + nodes_changed = true; + } + + if nodes_changed { + // Update the graph before invalidating downstream nodes. + drop(graph); + self._modify(|_| {}) + .expect("modifying graph during evaluation must produce valid graph"); + graph = self.node_graph.borrow_mut(); + } + + if ctx.invalidate_dependent_nodes { // Invalidate any downstream nodes (which we know we haven't visited yet, because // we're iterating over a topological sort of the graph). let dependents = graph @@ -270,14 +320,25 @@ impl Graph { dependent.invalidate(); } } + + if nodes_changed { + // If we added/removed nodes, the sorted order has changed, so start evaluating + // from the beginning, in case of changes before i. + i = 0; + continue; + } } + + i += 1; } + // Consistency check: after updating in the topological sort order, we should be left with - // no invalid nodes + // no invalid nodes. debug_assert!(self .sorted_nodes .iter() .all(|&idx| { (&graph[idx]).is_valid() })); + self.is_valid.set(true); } @@ -300,13 +361,51 @@ impl Graph { async fn update_invalid_nodes(&mut self) { // TODO: consider whether this can be done in parallel to any degree. let mut graph = self.node_graph.borrow_mut(); - for &idx in self.sorted_nodes.iter() { + let mut i = 0; + while i < self.sorted_nodes.len() { + let idx = self.sorted_nodes[i]; let node = &mut graph[idx]; if !node.is_valid() { // Update this node - let value_changed = node.update().await; + let mut ctx = NodeUpdateContext::new(); + node.update(&mut ctx).await; - if value_changed { + let mut nodes_changed = false; + for idx_to_remove in ctx.removed_nodes { + assert!( + idx_to_remove != idx, + "cannot remove node curently being evaluated" + ); + let (index_to_remove_in_sorted, _) = self + .sorted_nodes + .iter() + .enumerate() + .find(|(_, idx)| **idx == idx_to_remove) + .expect("removed node must have been already added"); + assert!( + index_to_remove_in_sorted > i, + "cannot remove already evaluated node" + ); + graph.remove_node(idx_to_remove); + self.sorted_nodes.remove(index_to_remove_in_sorted); + nodes_changed = true; + } + + for (added_node, id_cell) in ctx.added_nodes { + let id = graph.add_node(added_node); + id_cell.set(Some(id)); + nodes_changed = true; + } + + if nodes_changed { + // Update the graph before invalidating downstream nodes. + drop(graph); + self._modify(|_| {}) + .expect("modifying graph during evaluation must produce valid graph"); + graph = self.node_graph.borrow_mut(); + } + + if ctx.invalidate_dependent_nodes { // Invalidate any downstream nodes (which we know we haven't visited yet, because // we're iterating over a topological sort of the graph). let dependents = graph @@ -320,14 +419,25 @@ impl Graph { dependent.invalidate(); } } + + if nodes_changed { + // If we added/removed nodes, the sorted order has changed, so start evaluating + // from the beginning, in case of changes before i. + i = 0; + continue; + } } + + i += 1; } + // Consistency check: after updating in the topological sort order, we should be left with // no invalid nodes debug_assert!(self .sorted_nodes .iter() .all(|&idx| { (&graph[idx]).is_valid() })); + self.is_valid.set(true); } @@ -420,7 +530,9 @@ impl Clone for ValueInvalidationSignal { #[cfg(test)] mod tests { use super::*; - use crate::rule::{ConstantRule, InputVisitable}; + use crate::rule::{ + AsyncDynamicRule, AsyncRule, ConstantRule, DynamicInput, DynamicRule, InputVisitable, Rule, + }; #[test] fn rule_output_with_no_inputs() { @@ -711,13 +823,108 @@ mod tests { assert_eq!( graph.as_dot_string(), r#"digraph { - 0 [label ="ConstNode (id=0)"] - 1 [label ="ConstNode (id=1)"] - 2 [label ="RuleNode(test) (id=2)"] + 0 [label="ConstNode (id=0)"] + 1 [label="ConstNode (id=1)"] + 2 [label="RuleNode(test) (id=2)"] 0 -> 2 [] 1 -> 2 [] } "# ) } + + #[test] + fn dynamic_rule() { + let mut builder = GraphBuilder::new(); + let (count, set_count) = builder.add_invalidatable_value(1); + struct CountUpTo(Input, Vec>); + impl InputVisitable for CountUpTo { + fn visit_inputs(&self, visitor: &mut impl InputVisitor) { + visitor.visit(&self.0); + } + } + impl DynamicRule for CountUpTo { + type ChildOutput = i32; + fn evaluate( + &mut self, + ctx: &mut impl rule::DynamicRuleContext, + ) -> Vec> { + let count = *self.0.value(); + assert!(count >= self.1.len() as i32); + while (self.1.len() as i32) < count { + self.1 + .push(ctx.add_rule(ConstantRule::new(self.1.len() as i32 + 1))); + } + self.1.clone() + } + } + let all_inputs = builder.add_dynamic_rule(CountUpTo(count, vec![])); + struct Sum(DynamicInput); + impl InputVisitable for Sum { + fn visit_inputs(&self, visitor: &mut impl InputVisitor) { + visitor.visit_dynamic(&self.0); + } + } + impl Rule for Sum { + type Output = i32; + fn evaluate(&mut self) -> Self::Output { + self.0.value().inputs.iter().map(|i| *i.value()).sum() + } + } + builder.set_output(Sum(all_inputs)); + let mut graph = builder.build().unwrap(); + assert_eq!(*graph.evaluate(), 1); + set_count.set_value(2); + assert_eq!(*graph.evaluate(), 3); + set_count.set_value(4); + assert_eq!(*graph.evaluate(), 10); + println!("{}", graph.as_dot_string()); + } + + #[tokio::test] + async fn async_dynamic_rule() { + let mut builder = GraphBuilder::new_async(); + let (count, set_count) = builder.add_invalidatable_value(1); + struct CountUpTo(Input, Vec>); + impl InputVisitable for CountUpTo { + fn visit_inputs(&self, visitor: &mut impl InputVisitor) { + visitor.visit(&self.0); + } + } + impl AsyncDynamicRule for CountUpTo { + type ChildOutput = i32; + async fn evaluate<'a>( + &'a mut self, + ctx: &'a mut impl rule::AsyncDynamicRuleContext, + ) -> Vec> { + let count = *self.0.value(); + assert!(count >= self.1.len() as i32); + while (self.1.len() as i32) < count { + self.1 + .push(ctx.add_rule(ConstantRule::new(self.1.len() as i32 + 1))); + } + self.1.clone() + } + } + let all_inputs = builder.add_async_dynamic_rule(CountUpTo(count, vec![])); + struct Sum(DynamicInput); + impl InputVisitable for Sum { + fn visit_inputs(&self, visitor: &mut impl InputVisitor) { + visitor.visit_dynamic(&self.0); + } + } + impl Rule for Sum { + type Output = i32; + fn evaluate(&mut self) -> Self::Output { + self.0.value().inputs.iter().map(|i| *i.value()).sum() + } + } + builder.set_output(Sum(all_inputs)); + let mut graph = builder.build().unwrap(); + assert_eq!(*graph.evaluate_async().await, 1); + set_count.set_value(2); + assert_eq!(*graph.evaluate_async().await, 3); + set_count.set_value(4); + assert_eq!(*graph.evaluate_async().await, 10); + } } diff --git a/crates/compute_graph/src/node.rs b/crates/compute_graph/src/node.rs index be91a8c..f868ae4 100644 --- a/crates/compute_graph/src/node.rs +++ b/crates/compute_graph/src/node.rs @@ -1,8 +1,12 @@ +use crate::rule::{ + AsyncDynamicRule, AsyncDynamicRuleContext, AsyncRule, DynamicInput, DynamicRule, + DynamicRuleContext, InputVisitable, Rule, +}; use crate::synchronicity::{Asynchronous, Synchronicity}; -use crate::{AsyncRule, Input, InputVisitor, NodeId, Rule, Synchronous}; +use crate::{Input, InputVisitor, NodeId, Synchronous}; use quote::ToTokens; use std::any::Any; -use std::cell::RefCell; +use std::cell::{Cell, RefCell}; use std::future::Future; use std::rc::Rc; @@ -11,10 +15,35 @@ pub(crate) struct ErasedNode { is_valid: Box) -> bool>, invalidate: Box) -> ()>, visit_inputs: Box, &mut dyn FnMut(NodeId) -> ()) -> ()>, - update: Box Fn(&'a mut Box) -> Synch::UpdateResult<'a>>, + update: Box< + dyn for<'a> Fn( + &'a mut Box, + &'a mut NodeUpdateContext, + ) -> Synch::UpdateResult<'a>, + >, debug_fmt: Box, &mut std::fmt::Formatter<'_>) -> std::fmt::Result>, } +pub(crate) struct NodeUpdateContext { + pub(crate) invalidate_dependent_nodes: bool, + pub(crate) removed_nodes: Vec, + pub(crate) added_nodes: Vec<(ErasedNode, Rc>>)>, +} + +impl NodeUpdateContext { + pub(crate) fn new() -> Self { + Self { + invalidate_dependent_nodes: false, + removed_nodes: vec![], + added_nodes: vec![], + } + } + + fn invalidate_dependent_nodes(&mut self) { + self.invalidate_dependent_nodes = true; + } +} + impl ErasedNode { pub(crate) fn new + 'static, V: NodeValue>(base: N) -> Self { // i don't love the double boxing, but i'm not sure how else to do this @@ -34,9 +63,9 @@ impl ErasedNode { let x = any.downcast_ref::>>().unwrap(); x.visit_inputs(visitor); }), - update: Box::new(|any| { + update: Box::new(|any, ctx| { let x = any.downcast_mut::>>().unwrap(); - x.update() + x.update(ctx) }), debug_fmt: Box::new(|any, f| { let x = any.downcast_ref::>>().unwrap(); @@ -57,14 +86,14 @@ impl ErasedNode { } impl ErasedNode { - pub(crate) fn update(&mut self) -> bool { - (self.update)(&mut self.any) + pub(crate) fn update(&mut self, ctx: &mut NodeUpdateContext) { + (self.update)(&mut self.any, ctx) } } impl ErasedNode { - pub(crate) async fn update(&mut self) -> bool { - (self.update)(&mut self.any).await + pub(crate) async fn update(&mut self, ctx: &mut NodeUpdateContext) { + (self.update)(&mut self.any, ctx).await } } @@ -78,7 +107,7 @@ pub(crate) trait Node: std::fmt::Debug { fn is_valid(&self) -> bool; fn invalidate(&mut self); fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()); - fn update(&mut self) -> Synch::UpdateResult<'_>; + fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext) -> Synch::UpdateResult<'a>; fn value_rc(&self) -> &Rc>>; } @@ -139,7 +168,7 @@ impl Node for ConstNode { fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {} - fn update(&mut self) -> S::UpdateResult<'_> { + fn update<'a>(&'a mut self, _ctx: &'a mut NodeUpdateContext) -> S::UpdateResult<'a> { unreachable!() } @@ -181,11 +210,12 @@ impl Node for InvalidatableConstNode fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {} - fn update(&mut self) -> S::UpdateResult<'_> { + fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext) -> S::UpdateResult<'a> { self.valid = true; // This node is only invalidate when node_value_eq between the old/new value is false, // so it is always the case that the update method has changed the value. - S::make_update_result(true, crate::synchronicity::private::Token) + ctx.invalidate_dependent_nodes(); + S::make_update_result(crate::synchronicity::private::Token) } fn value_rc(&self) -> &Rc>> { @@ -217,6 +247,32 @@ impl RuleNode { } } +fn visit_inputs(visitable: &V, visitor: &mut dyn FnMut(NodeId) -> ()) { + struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ()); + impl<'a> InputVisitor for InputIndexVisitor<'a> { + fn visit(&mut self, input: &Input) { + self.0(input.node_idx.get().unwrap()); + } + fn visit_dynamic(&mut self, input: &DynamicInput) { + // Visit the dynamic node itself + self.visit(&input.input); + + // And visit all the nodes it produces + let maybe_dynamic_output = input.input.value.borrow(); + if let Some(dynamic_output) = maybe_dynamic_output.as_ref() { + for input in dynamic_output.inputs.iter() { + self.visit(input); + } + } else { + // Haven't evaluated the dynamic node for the first time yet. + // Upon doing so, if the nodes it produces change, we'll modify the graph + // and end up back here in the other branch. + } + } + } + visitable.visit_inputs(&mut InputIndexVisitor(visitor)); +} + impl Node for RuleNode { fn is_valid(&self) -> bool { self.valid @@ -227,16 +283,10 @@ impl Node for RuleNode } fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) { - struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ()); - impl<'a> InputVisitor for InputIndexVisitor<'a> { - fn visit(&mut self, input: &Input) { - self.0(input.node_idx); - } - } - self.rule.visit_inputs(&mut InputIndexVisitor(visitor)); + visit_inputs(&self.rule, visitor); } - fn update(&mut self) -> S::UpdateResult<'_> { + fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext) -> S::UpdateResult<'a> { self.valid = true; let new_value = self.rule.evaluate(); @@ -247,9 +297,10 @@ impl Node for RuleNode if value_changed { *value = Some(new_value); + ctx.invalidate_dependent_nodes(); } - S::make_update_result(value_changed, crate::synchronicity::private::Token) + S::make_update_result(crate::synchronicity::private::Token) } fn value_rc(&self) -> &Rc>> { @@ -290,12 +341,12 @@ impl F, F: Future> AsyncConstNode { } } - async fn do_update(&mut self) -> bool { + async fn do_update(&mut self, ctx: &mut NodeUpdateContext) { self.valid = true; let mut provider = None; std::mem::swap(&mut self.provider, &mut provider); *self.value.borrow_mut() = Some(provider.unwrap()().await); - true + ctx.invalidate_dependent_nodes(); } } @@ -312,8 +363,11 @@ impl F, F: Future> Node ()) {} - fn update(&mut self) -> ::UpdateResult<'_> { - Box::pin(self.do_update()) + fn update<'a>( + &'a mut self, + ctx: &'a mut NodeUpdateContext, + ) -> ::UpdateResult<'a> { + Box::pin(self.do_update(ctx)) } fn value_rc(&self) -> &Rc>> { @@ -342,7 +396,7 @@ impl AsyncRuleNode { } } - async fn do_update(&mut self) -> bool { + async fn do_update(&mut self, ctx: &mut NodeUpdateContext) { self.valid = true; let new_value = self.rule.evaluate().await; @@ -353,9 +407,8 @@ impl AsyncRuleNode { if value_changed { *value = Some(new_value); + ctx.invalidate_dependent_nodes(); } - - value_changed } } @@ -369,17 +422,14 @@ impl Node for AsyncRuleNode } fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) { - struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ()); - impl<'a> InputVisitor for InputIndexVisitor<'a> { - fn visit(&mut self, input: &Input) { - self.0(input.node_idx); - } - } - self.rule.visit_inputs(&mut InputIndexVisitor(visitor)); + visit_inputs(&self.rule, visitor); } - fn update(&mut self) -> ::UpdateResult<'_> { - Box::pin(self.do_update()) + fn update<'a>( + &'a mut self, + ctx: &'a mut NodeUpdateContext, + ) -> ::UpdateResult<'a> { + Box::pin(self.do_update(ctx)) } fn value_rc(&self) -> &Rc>> { @@ -405,6 +455,236 @@ impl std::fmt::Debug for AsyncRuleNode { } } +// todo: better name for this +pub struct DynamicRuleOutput { + pub inputs: Vec>, +} + +impl NodeValue for DynamicRuleOutput { + fn node_value_eq(&self, other: &Self) -> bool { + if self.inputs.len() != other.inputs.len() { + return false; + } + self.inputs + .iter() + .zip(other.inputs.iter()) + .all(|(s, o)| s.node_idx == o.node_idx) + } +} + +impl std::fmt::Debug for DynamicRuleOutput { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct(std::any::type_name::()) + .field("inputs", &self.inputs) + .finish() + } +} + +pub(crate) struct DynamicRuleNode { + rule: R, + valid: bool, + value: Rc>>>, + synchronicity: std::marker::PhantomData, +} + +impl DynamicRuleNode { + pub(crate) fn new(rule: R) -> Self { + Self { + rule, + valid: false, + value: Rc::new(RefCell::new(None)), + synchronicity: std::marker::PhantomData, + } + } +} + +impl Node, S> + for DynamicRuleNode +{ + fn is_valid(&self) -> bool { + self.valid + } + + fn invalidate(&mut self) { + self.valid = false; + } + + fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) { + visit_inputs(&self.rule, visitor); + } + + fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext) -> S::UpdateResult<'a> { + self.valid = true; + + let new_value = DynamicRuleOutput { + inputs: self.rule.evaluate(&mut DynamicRuleUpdateContext(ctx)), + }; + let mut value = self.value.borrow_mut(); + let value_changed = value + .as_ref() + .map_or(true, |v| !v.node_value_eq(&new_value)); + + if value_changed { + *value = Some(new_value); + ctx.invalidate_dependent_nodes(); + } + + S::make_update_result(crate::synchronicity::private::Token) + } + + fn value_rc(&self) -> &Rc>>> { + &self.value + } +} + +struct DynamicRuleUpdateContext<'a, Synch: Synchronicity>(&'a mut NodeUpdateContext); + +impl<'a, S: Synchronicity> DynamicRuleUpdateContext<'a, S> { + fn add_node(&mut self, node: impl Node + 'static) -> Input { + let node_idx = Rc::new(Cell::new(None)); + let value = Rc::clone(node.value_rc()); + let erased = ErasedNode::new(node); + self.0.added_nodes.push((erased, Rc::clone(&node_idx))); + Input { node_idx, value } + } +} + +impl<'a, S: Synchronicity> DynamicRuleContext for DynamicRuleUpdateContext<'a, S> { + fn remove_node(&mut self, id: NodeId) { + self.0.removed_nodes.push(id); + } + + fn add_rule(&mut self, rule: R) -> Input + where + R: Rule, + { + self.add_node(RuleNode::new(rule)) + } +} + +struct DynamicRuleLabel<'a, R: DynamicRule>(&'a R); +impl<'a, R: DynamicRule> std::fmt::Display for DynamicRuleLabel<'a, R> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.node_label(f) + } +} + +impl std::fmt::Debug for DynamicRuleNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "DynamicRuleNode<{}>({})", + pretty_type_name::(), + DynamicRuleLabel(&self.rule) + ) + } +} + +pub(crate) struct AsyncDynamicRuleNode { + rule: R, + valid: bool, + value: Rc>>>, +} + +impl AsyncDynamicRuleNode { + pub(crate) fn new(rule: R) -> Self { + Self { + rule, + valid: false, + value: Rc::new(RefCell::new(None)), + } + } + + async fn do_update(&mut self, ctx: &mut NodeUpdateContext) { + self.valid = true; + + let new_value = DynamicRuleOutput { + inputs: self + .rule + .evaluate(&mut AsyncDynamicRuleUpdateContext(ctx)) + .await, + }; + let mut value = self.value.borrow_mut(); + let value_changed = value + .as_ref() + .map_or(true, |v| !v.node_value_eq(&new_value)); + + if value_changed { + *value = Some(new_value); + ctx.invalidate_dependent_nodes(); + } + } +} + +impl Node, Asynchronous> + for AsyncDynamicRuleNode +{ + fn is_valid(&self) -> bool { + self.valid + } + + fn invalidate(&mut self) { + self.valid = false; + } + + fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) { + visit_inputs(&self.rule, visitor); + } + + fn update<'a>( + &'a mut self, + ctx: &'a mut NodeUpdateContext, + ) -> ::UpdateResult<'a> { + Box::pin(self.do_update(ctx)) + } + + fn value_rc(&self) -> &Rc>>> { + &self.value + } +} + +struct AsyncDynamicRuleUpdateContext<'a>(&'a mut NodeUpdateContext); + +impl<'a> DynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> { + fn remove_node(&mut self, id: NodeId) { + DynamicRuleUpdateContext(self.0).remove_node(id); + } + + fn add_rule(&mut self, rule: R) -> Input + where + R: Rule, + { + DynamicRuleUpdateContext(self.0).add_rule(rule) + } +} + +impl<'a> AsyncDynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> { + fn add_async_rule(&mut self, rule: R) -> Input + where + R: AsyncRule, + { + DynamicRuleUpdateContext(self.0).add_node(AsyncRuleNode::new(rule)) + } +} + +struct AsyncDynamicRuleLabel<'a, R: AsyncDynamicRule>(&'a R); +impl<'a, R: AsyncDynamicRule> std::fmt::Display for AsyncDynamicRuleLabel<'a, R> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.node_label(f) + } +} + +impl std::fmt::Debug for AsyncDynamicRuleNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "AsyncDynamicRuleNode<{}>({})", + pretty_type_name::(), + AsyncDynamicRuleLabel(&self.rule) + ) + } +} + fn pretty_type_name() -> String { let s = std::any::type_name::(); let ty = syn::parse_str::(s).unwrap(); diff --git a/crates/compute_graph/src/rule.rs b/crates/compute_graph/src/rule.rs index 699b410..862ba12 100644 --- a/crates/compute_graph/src/rule.rs +++ b/crates/compute_graph/src/rule.rs @@ -1,7 +1,7 @@ -use crate::node::NodeValue; +use crate::node::{DynamicRuleOutput, NodeValue}; use crate::NodeId; pub use compute_graph_macros::InputVisitable; -use std::cell::{Ref, RefCell}; +use std::cell::{Cell, Ref, RefCell}; use std::future::Future; use std::ops::Deref; use std::rc::Rc; @@ -76,6 +76,75 @@ pub trait AsyncRule: InputVisitable + 'static { } } +/// A rule whose output is further nodes in the graph. +/// +/// Types implementing this rule should track which nodes they previously output and not +/// add additional equivalent nodes (for whatever domain-specific definition of equivalence) +/// on susbequent evaluations. +pub trait DynamicRule: InputVisitable + 'static { + /// The type of the output value of each of the child nodes that this rule produces. + type ChildOutput: NodeValue; + + /// Evaluates this rule, producing additional nodes. + /// + /// Use the methods on [`DynamicRuleContext`] to add or remove nodes from the graph. + fn evaluate(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec>; + + #[allow(unused_variables)] + fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + Ok(()) + } +} + +/// Facilities for adding/removing nodes in the graph during the update of a [`DynamicRule`]. +// todo: better abstracion for this +// something that handles diffing and does the add/remove automatically +pub trait DynamicRuleContext { + /// Removes the node with the given ID from the graph. + /// + /// Be careful when removing nodes. Removing a node that is still depended-upon by another node + /// (i.e., is an input in some other node's [`InputVisitable::visit_inputs`]) is an error. + fn remove_node(&mut self, id: NodeId); + + /// Adds a node whose value is produced using the given rule to the graph. + /// + /// Returns an [`Input`] representing the newly-added node, which can be used to construct further rules. + fn add_rule(&mut self, rule: R) -> Input + where + R: Rule; +} + +/// An asynchronous rule whose output is further nodes in the graph. +/// +/// See [`DynamicRule`]. +pub trait AsyncDynamicRule: InputVisitable + 'static { + /// The type of the output value of each of the child nodes that this rule produces. + type ChildOutput: NodeValue; + + /// Evaluates this rule asynchronously, producing additional nodes. + /// + /// Use the methods on [`AsyncDynamicRuleContext`] to add or remove nodes from the graph. + fn evaluate<'a>( + &'a mut self, + ctx: &'a mut impl AsyncDynamicRuleContext, + ) -> impl Future>> + 'a; + + #[allow(unused_variables)] + fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + Ok(()) + } +} + +/// Facilities for adding/removing nodes in the graph during the update of an [`AsyncDynamicRule`]. +pub trait AsyncDynamicRuleContext: DynamicRuleContext { + /// Adds a node whose value is produced using the given rule to the graph. + /// + /// Returns an [`Input`] representing the newly-added node, which can be used to construct further rules. + fn add_async_rule(&mut self, rule: R) -> Input + where + R: AsyncRule; +} + /// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited. /// /// The implementation of this trait can generally be derived using [`derive@InputVisitable`]. @@ -93,13 +162,13 @@ pub trait InputVisitable { fn visit_inputs(&self, visitor: &mut impl InputVisitor); } -/// A placeholder for the output of one node to be used as an input for another. +/// A placeholder for the output of one node, to be used as an input for another. /// /// To obtain an input, add a value or rule to a [`GraphBuilder`](`crate::builder::GraphBuilder`). /// /// Note that this type implements `Clone`, so can be cloned and used as an input for multiple nodes. pub struct Input { - pub(crate) node_idx: NodeId, + pub(crate) node_idx: Rc>>, pub(crate) value: Rc>>, } @@ -119,7 +188,7 @@ impl Input { impl Clone for Input { fn clone(&self) -> Self { Self { - node_idx: self.node_idx, + node_idx: Rc::clone(&self.node_idx), value: Rc::clone(&self.value), } } @@ -136,6 +205,25 @@ impl std::fmt::Debug for Input { } } +/// A placeholder for the output of a dynamic rule node, to be used as an input for another. +/// +/// See [`GraphBuilder::add_dynamic_rule`](`crate::builder::GraphBuilder::add_dynamic_rule`). +/// +/// A dependency on a dynamic input represents both a dependency on the dynamic node itself, +/// as well as dependencies on each of the nodes that are the output of the dynamic node. +#[derive(Clone)] +pub struct DynamicInput { + pub(crate) input: Input>, +} + +impl DynamicInput { + /// Retrieves a reference to the current value of the dynamic node (i.e., the set of inputs + /// representing the nodes that are the outputs of the dynamic node). + pub fn value(&self) -> impl Deref> + '_ { + self.input.value() + } +} + // TODO: i really want Input to be able to implement Deref somehow /// A type that can visit arbitrary [`Input`]s. @@ -145,6 +233,9 @@ impl std::fmt::Debug for Input { pub trait InputVisitor { /// Visit an input whose value is of type `T`. fn visit(&mut self, input: &Input); + + /// Visit a dynamic input whose child value is of type `T`. + fn visit_dynamic(&mut self, input: &DynamicInput); } /// A simple rule that provides a constant value. diff --git a/crates/compute_graph/src/synchronicity.rs b/crates/compute_graph/src/synchronicity.rs index 90ab33c..59d2ec8 100644 --- a/crates/compute_graph/src/synchronicity.rs +++ b/crates/compute_graph/src/synchronicity.rs @@ -11,7 +11,7 @@ pub(crate) mod private { pub trait Sealed {} impl Sealed for super::Synchronous {} impl Sealed for super::Asynchronous {} - impl Sealed for bool {} + impl Sealed for () {} impl<'a> Sealed for ::UpdateResult<'a> {} pub struct Token; } @@ -20,25 +20,23 @@ pub trait Synchronicity: private::Sealed + 'static { type UpdateResult<'a>: private::Sealed; // Necessary for synchronous nodes that can be part of an async graph to return the // appropriate result based on the type of graph they're in. - fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a>; + fn make_update_result<'a>(_: private::Token) -> Self::UpdateResult<'a>; } pub struct Synchronous; impl Synchronicity for Synchronous { - type UpdateResult<'a> = bool; + type UpdateResult<'a> = (); - fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a> { - result - } + fn make_update_result<'a>(_: private::Token) -> Self::UpdateResult<'a> {} } pub struct Asynchronous; impl Synchronicity for Asynchronous { - type UpdateResult<'a> = Pin + 'a>>; + type UpdateResult<'a> = Pin + 'a>>; - fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a> { - Box::pin(std::future::ready(result)) + fn make_update_result<'a>(_: private::Token) -> Self::UpdateResult<'a> { + Box::pin(std::future::ready(())) } } diff --git a/crates/compute_graph_macros/src/lib.rs b/crates/compute_graph_macros/src/lib.rs index 75abec2..dc453be 100644 --- a/crates/compute_graph_macros/src/lib.rs +++ b/crates/compute_graph_macros/src/lib.rs @@ -1,6 +1,6 @@ use proc_macro::TokenStream; use proc_macro2::Literal; -use quote::{format_ident, quote}; +use quote::{format_ident, quote, ToTokens}; use syn::{ parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, GenericParam, PathArguments, Type, @@ -10,8 +10,8 @@ extern crate proc_macro; /// Derive an implementation of the `InputVisitable` trait and helper methods. /// -/// This macro generates an implementation of the `InputVisitable` trait and the `visit_input` method that -/// calls `visit` on each field of the struct that is of type `Input` for any T. +/// This macro generates an implementation of the `InputVisitable` trait and the `visit_inputs` method that +/// calls `visit` on each field of the struct that is of type `Input` or `DynamicInput` for any `T`. /// /// The macro also generates helper methods for accessing the value of each input less verbosely. /// For unnamed struct fields, the methods generated have the form `input_0`, `input_1`, etc. @@ -56,20 +56,34 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream { Fields::Named(ref named) => named .named .iter() - .filter(|field| input_value_type(field).is_some()) - .map(|field| { - let ident = field.ident.as_ref().unwrap(); - quote!(visitor.visit(&self.#ident);) + .flat_map(|field| { + if let Some((_ty, is_dynamic)) = input_value_type(field) { + let ident = field.ident.as_ref().unwrap(); + if is_dynamic { + Some(quote!(visitor.visit_dynamic(&self.#ident);)) + } else { + Some(quote!(visitor.visit(&self.#ident);)) + } + } else { + None + } }) .collect::>(), Fields::Unnamed(ref unnamed) => unnamed .unnamed .iter() .enumerate() - .filter(|(_, field)| input_value_type(field).is_some()) - .map(|(i, _)| { - let idx_lit = Literal::usize_unsuffixed(i); - quote!(visitor.visit(&self.#idx_lit);) + .flat_map(|(i, field)| { + if let Some((_ty, is_dynamic)) = input_value_type(field) { + let idx_lit = Literal::usize_unsuffixed(i); + if is_dynamic { + Some(quote!(visitor.visit_dynamic(&self.#idx_lit);)) + } else { + Some(quote!(visitor.visit(&self.#idx_lit);)) + } + } else { + None + } }) .collect::>(), Fields::Unit => vec![], @@ -79,12 +93,19 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream { Fields::Named(ref named) => named .named .iter() - .filter_map(|field| input_value_type(field).map(|ty| (field, ty))) - .map(|(field, ty)| { + .filter_map(|field| { + input_value_type(field).map(|(ty, is_dynamic)| (field, ty, is_dynamic)) + }) + .map(|(field, ty, is_dynamic)| { let ident = field.ident.as_ref().unwrap(); + let target = if is_dynamic { + quote!(::compute_graph::node::DynamicRuleOutput<#ty>) + } else { + ty.to_token_stream() + }; quote!( - fn #ident(&self) -> impl ::std::ops::Deref + '_ { + fn #ident(&self) -> impl ::std::ops::Deref + '_ { self.#ident.value() } @@ -95,13 +116,20 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream { .unnamed .iter() .enumerate() - .filter_map(|(i, field)| input_value_type(field).map(|ty| (i, ty))) - .map(|(i, ty)| { + .filter_map(|(i, field)| { + input_value_type(field).map(|(ty, is_dynamic)| (i, ty, is_dynamic)) + }) + .map(|(i, ty, is_dynamic)| { let idx_lit = Literal::usize_unsuffixed(i); let ident = format_ident!("input_{i}"); + let target = if is_dynamic { + quote!(::compute_graph::node::DynamicRuleOutput<#ty>) + } else { + ty.to_token_stream() + }; quote!( - fn #ident(&self) -> impl ::std::ops::Deref + '_ { + fn #ident(&self) -> impl ::std::ops::Deref + '_ { self.#idx_lit.value() } @@ -126,14 +154,15 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream { )) } -fn input_value_type(field: &Field) -> Option<&Type> { +fn input_value_type(field: &Field) -> Option<(&Type, bool)> { if let Type::Path(ref path) = field.ty { let last_segment = path.path.segments.last().unwrap(); - if last_segment.ident == "Input" { + if last_segment.ident == "Input" || last_segment.ident == "DynamicInput" { + let is_dynamic = last_segment.ident == "DynamicInput"; if let PathArguments::AngleBracketed(ref args) = last_segment.arguments { if args.args.len() == 1 { if let GenericArgument::Type(ref ty) = args.args.first().unwrap() { - Some(ty) + Some((ty, is_dynamic)) } else { None } diff --git a/crates/derive_test/src/lib.rs b/crates/derive_test/src/lib.rs index aa9d62b..becb59c 100644 --- a/crates/derive_test/src/lib.rs +++ b/crates/derive_test/src/lib.rs @@ -1,5 +1,5 @@ use compute_graph::node::NodeValue; -use compute_graph::rule::{Input, InputVisitable, Rule}; +use compute_graph::rule::{DynamicInput, Input, InputVisitable, Rule}; #[derive(InputVisitable)] struct Add(Input, Input, i32); @@ -34,9 +34,25 @@ impl Rule for Passthrough { } } +#[derive(InputVisitable)] +struct Sum(DynamicInput); +impl Rule for Sum { + type Output = i32; + fn evaluate(&mut self) -> Self::Output { + self.input_0() + .inputs + .iter() + .map(|input| *input.value()) + .sum() + } +} + #[cfg(test)] mod tests { - use compute_graph::builder::GraphBuilder; + use compute_graph::{ + builder::GraphBuilder, + rule::{ConstantRule, DynamicRule}, + }; use super::*; @@ -59,4 +75,27 @@ mod tests { let mut graph = builder.build().unwrap(); assert_eq!(*graph.evaluate(), 6); } + + #[test] + fn test_sum() { + #[derive(InputVisitable)] + struct Dynamic; + impl DynamicRule for Dynamic { + type ChildOutput = i32; + fn evaluate( + &mut self, + ctx: &mut impl compute_graph::rule::DynamicRuleContext, + ) -> Vec> { + vec![ + ctx.add_rule(ConstantRule::new(1)), + ctx.add_rule(ConstantRule::new(2)), + ] + } + } + let mut builder = GraphBuilder::new(); + let dynamic_input = builder.add_dynamic_rule(Dynamic); + builder.set_output(Sum(dynamic_input)); + let mut graph = builder.build().unwrap(); + assert_eq!(*graph.evaluate(), 3); + } }