Dynamic rules

This commit is contained in:
Shadowfacts 2024-12-29 13:37:54 -05:00
parent 9cb6a8c6ce
commit d92ebf11b2
7 changed files with 803 additions and 103 deletions

View File

@ -1,8 +1,8 @@
use crate::node::{ use crate::node::{
AsyncConstNode, AsyncRuleNode, ConstNode, ErasedNode, InvalidatableConstNode, Node, NodeValue, AsyncConstNode, AsyncDynamicRuleNode, AsyncRuleNode, ConstNode, DynamicRuleNode, ErasedNode,
RuleNode, InvalidatableConstNode, Node, NodeValue, RuleNode,
}; };
use crate::rule::{AsyncRule, Input, Rule}; use crate::rule::{AsyncDynamicRule, AsyncRule, DynamicInput, DynamicRule, Input, Rule};
use crate::synchronicity::{Asynchronous, Synchronicity, Synchronous}; use crate::synchronicity::{Asynchronous, Synchronicity, Synchronous};
use crate::util; use crate::util;
use crate::{Graph, InvalidationSignal, NodeGraph, NodeId, ValueInvalidationSignal}; use crate::{Graph, InvalidationSignal, NodeGraph, NodeId, ValueInvalidationSignal};
@ -73,7 +73,7 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
let erased = ErasedNode::new(node); let erased = ErasedNode::new(node);
let idx = self.node_graph.borrow_mut().add_node(erased); let idx = self.node_graph.borrow_mut().add_node(erased);
Input { Input {
node_idx: idx, node_idx: Rc::new(Cell::new(Some(idx))),
value, value,
} }
} }
@ -174,19 +174,42 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
} }
fn make_invalidation_signal<V>(&self, input: &Input<V>) -> InvalidationSignal { fn make_invalidation_signal<V>(&self, input: &Input<V>) -> InvalidationSignal {
let node_idx = input.node_idx; let node_idx = Rc::clone(&input.node_idx);
let graph = Rc::clone(&self.node_graph); let graph = Rc::clone(&self.node_graph);
let graph_is_valid = Rc::clone(&self.is_valid); let graph_is_valid = Rc::clone(&self.is_valid);
InvalidationSignal { InvalidationSignal {
do_invalidate: Rc::new(Box::new(move || { do_invalidate: Rc::new(Box::new(move || {
graph_is_valid.set(false); graph_is_valid.set(false);
let mut graph = graph.borrow_mut(); let mut graph = graph.borrow_mut();
let node = &mut graph[node_idx]; let node = &mut graph[node_idx.get().unwrap()];
node.invalidate(); node.invalidate();
})), })),
} }
} }
/// Adds a node to the graph whose output is additional nodes produced by the given rule.
pub fn add_dynamic_rule<R>(&mut self, rule: R) -> DynamicInput<R::ChildOutput>
where
R: DynamicRule,
{
let input = self.add_node(DynamicRuleNode::<R, R::ChildOutput, S>::new(rule));
DynamicInput { input }
}
/// Adds an externally-invalidatable node to the graph whose output is additional
/// nodes produced by the given rule.
pub fn add_invalidatable_dynamic_rule<R>(
&mut self,
rule: R,
) -> (DynamicInput<R::ChildOutput>, InvalidationSignal)
where
R: DynamicRule,
{
let input = self.add_dynamic_rule(rule);
let signal = self.make_invalidation_signal(&input.input);
(input, signal)
}
/// Creates a graph from this builder, consuming the builder. /// Creates a graph from this builder, consuming the builder.
/// ///
/// To successfully build a graph, there must be an output node set (using either /// To successfully build a graph, there must be an output node set (using either
@ -217,7 +240,7 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
graph.add_edge(source, dest, ()); graph.add_edge(source, dest, ());
} }
util::remove_nodes_not_connected_to(&mut *graph, output.node_idx); util::remove_nodes_not_connected_to(&mut *graph, output.node_idx.get().unwrap());
drop(graph); drop(graph);
@ -319,6 +342,29 @@ impl<O: 'static> GraphBuilder<O, Asynchronous> {
let signal = self.make_invalidation_signal(&input); let signal = self.make_invalidation_signal(&input);
(input, signal) (input, signal)
} }
/// Adds a node to the graph whose output is additional nodes produced asynchronously by the given rule.
pub fn add_async_dynamic_rule<R>(&mut self, rule: R) -> DynamicInput<R::ChildOutput>
where
R: AsyncDynamicRule,
{
let input = self.add_node(AsyncDynamicRuleNode::<R, R::ChildOutput>::new(rule));
DynamicInput { input }
}
/// Adds an externally-invalidatable node to the graph whose output is additional nodes produced
/// asynchronously by the given rule.
pub fn add_invalidatable_async_dynamic_rule<R>(
&mut self,
rule: R,
) -> (DynamicInput<R::ChildOutput>, InvalidationSignal)
where
R: AsyncDynamicRule,
{
let input = self.add_async_dynamic_rule(rule);
let signal = self.make_invalidation_signal(&input.input);
(input, signal)
}
} }
/// A reason why a [`GraphBuilder`] can fail to build a graph. /// A reason why a [`GraphBuilder`] can fail to build a graph.
@ -383,8 +429,18 @@ mod tests {
builder.set_output(Double::new(b.clone())); builder.set_output(Double::new(b.clone()));
match builder.build() { match builder.build() {
Err(super::BuildGraphError::Cycle(cycle)) => { Err(super::BuildGraphError::Cycle(cycle)) => {
let a_start = cycle == vec![a.node_idx, b.node_idx, a.node_idx]; let a_start = cycle
let b_start = cycle == vec![b.node_idx, a.node_idx, b.node_idx]; == vec![
a.node_idx.get().unwrap(),
b.node_idx.get().unwrap(),
a.node_idx.get().unwrap(),
];
let b_start = cycle
== vec![
b.node_idx.get().unwrap(),
a.node_idx.get().unwrap(),
b.node_idx.get().unwrap(),
];
// either is a permisisble way of describing the cycle // either is a permisisble way of describing the cycle
assert!(a_start || b_start); assert!(a_start || b_start);
} }

View File

@ -49,10 +49,10 @@ pub mod synchronicity;
mod util; mod util;
use builder::{BuildGraphError, GraphBuilder}; use builder::{BuildGraphError, GraphBuilder};
use node::{ErasedNode, NodeValue}; use node::{ErasedNode, NodeUpdateContext, NodeValue};
use petgraph::visit::{IntoEdgeReferences, IntoNodeReferences, NodeIndexable, NodeRef}; use petgraph::visit::{IntoEdgeReferences, IntoNodeReferences, NodeIndexable, NodeRef};
use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef}; use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef};
use rule::{AsyncRule, Input, InputVisitor, Rule}; use rule::{Input, InputVisitor};
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::collections::HashMap; use std::collections::HashMap;
use std::collections::VecDeque; use std::collections::VecDeque;
@ -127,7 +127,15 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
/// Because building a graph can fail and this method mutates the underlying graph, it takes /// Because building a graph can fail and this method mutates the underlying graph, it takes
/// ownership of the current graph to prevent the graph being left in an invalid state. /// ownership of the current graph to prevent the graph being left in an invalid state.
/// It returns either the new, modified graph or an error. /// It returns either the new, modified graph or an error.
pub fn modify<F>(mut self, mut f: F) -> Result<Self, BuildGraphError> pub fn modify<F>(mut self, f: F) -> Result<Self, BuildGraphError>
where
F: FnMut(&mut GraphBuilder<O, S>) -> (),
{
self._modify(f)?;
Ok(self)
}
fn _modify<F>(&mut self, mut f: F) -> Result<(), BuildGraphError>
where where
F: FnMut(&mut GraphBuilder<O, S>) -> (), F: FnMut(&mut GraphBuilder<O, S>) -> (),
{ {
@ -142,12 +150,12 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
} }
drop(graph); drop(graph);
let old_output = self.output.node_idx; let old_output = self.output.node_idx.get();
// Modify // Modify
let mut builder = self.into_builder(); let mut builder = self.to_builder();
f(&mut builder); f(&mut builder);
self = builder.build()?; *self = builder.build()?;
// Any new inboud edges invalidate their target nodes. // Any new inboud edges invalidate their target nodes.
let mut graph = self.node_graph.borrow_mut(); let mut graph = self.node_graph.borrow_mut();
@ -164,7 +172,7 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
} }
// Edge case: if the only node in the graph is the output node, and it's replaced in the modify block, // Edge case: if the only node in the graph is the output node, and it's replaced in the modify block,
// there are no edges but we still need to invalidate. // there are no edges but we still need to invalidate.
if !to_invalidate.is_empty() || self.output.node_idx != old_output { if !to_invalidate.is_empty() || self.output.node_idx.get() != old_output {
self.is_valid.set(false); self.is_valid.set(false);
for idx in to_invalidate { for idx in to_invalidate {
let node = &mut graph[idx]; let node = &mut graph[idx];
@ -173,13 +181,17 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
} }
drop(graph); drop(graph);
Ok(self) Ok(())
} }
/// Convert this graph back into a builder for further modifications. /// Convert this graph back into a builder for further modifications.
/// ///
/// Returns a builder with the same output and synchronicity types. /// Returns a builder with the same output and synchronicity types.
pub fn into_builder(self) -> GraphBuilder<O, S> { pub fn into_builder(self) -> GraphBuilder<O, S> {
self.to_builder()
}
fn to_builder(&self) -> GraphBuilder<O, S> {
// Clear the edges before modifying so that rebuilding results in a graph with up-to-date edges. // Clear the edges before modifying so that rebuilding results in a graph with up-to-date edges.
let mut graph = self.node_graph.borrow_mut(); let mut graph = self.node_graph.borrow_mut();
graph.clear_edges(); graph.clear_edges();
@ -232,7 +244,7 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
for node in self.0.node_references() { for node in self.0.node_references() {
let id = self.0.to_index(node.id()); let id = self.0.to_index(node.id());
let label = Escaped(node.weight()); let label = Escaped(node.weight());
writeln!(f, "\t{id} [label =\"{label:?} (id={id})\"]")?; writeln!(f, "\t{id} [label=\"{label:?} (id={id})\"]")?;
} }
for edge in self.0.edge_references() { for edge in self.0.edge_references() {
let source = self.0.to_index(edge.source()); let source = self.0.to_index(edge.source());
@ -250,13 +262,51 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
impl<O: 'static> Graph<O, Synchronous> { impl<O: 'static> Graph<O, Synchronous> {
fn update_invalid_nodes(&mut self) { fn update_invalid_nodes(&mut self) {
let mut graph = self.node_graph.borrow_mut(); let mut graph = self.node_graph.borrow_mut();
for &idx in self.sorted_nodes.iter() { let mut i = 0;
while i < self.sorted_nodes.len() {
let idx = self.sorted_nodes[i];
let node = &mut graph[idx]; let node = &mut graph[idx];
if !node.is_valid() { if !node.is_valid() {
// Update this node // Update this node
let value_changed = node.update(); let mut ctx = NodeUpdateContext::new();
node.update(&mut ctx);
if value_changed { let mut nodes_changed = false;
for idx_to_remove in ctx.removed_nodes {
assert!(
idx_to_remove != idx,
"cannot remove node curently being evaluated"
);
let (index_to_remove_in_sorted, _) = self
.sorted_nodes
.iter()
.enumerate()
.find(|(_, idx)| **idx == idx_to_remove)
.expect("removed node must have been already added");
assert!(
index_to_remove_in_sorted > i,
"cannot remove already evaluated node"
);
graph.remove_node(idx_to_remove);
self.sorted_nodes.remove(index_to_remove_in_sorted);
nodes_changed = true;
}
for (added_node, id_cell) in ctx.added_nodes {
let id = graph.add_node(added_node);
id_cell.set(Some(id));
nodes_changed = true;
}
if nodes_changed {
// Update the graph before invalidating downstream nodes.
drop(graph);
self._modify(|_| {})
.expect("modifying graph during evaluation must produce valid graph");
graph = self.node_graph.borrow_mut();
}
if ctx.invalidate_dependent_nodes {
// Invalidate any downstream nodes (which we know we haven't visited yet, because // Invalidate any downstream nodes (which we know we haven't visited yet, because
// we're iterating over a topological sort of the graph). // we're iterating over a topological sort of the graph).
let dependents = graph let dependents = graph
@ -270,14 +320,25 @@ impl<O: 'static> Graph<O, Synchronous> {
dependent.invalidate(); dependent.invalidate();
} }
} }
if nodes_changed {
// If we added/removed nodes, the sorted order has changed, so start evaluating
// from the beginning, in case of changes before i.
i = 0;
continue;
}
} }
i += 1;
} }
// Consistency check: after updating in the topological sort order, we should be left with // Consistency check: after updating in the topological sort order, we should be left with
// no invalid nodes // no invalid nodes.
debug_assert!(self debug_assert!(self
.sorted_nodes .sorted_nodes
.iter() .iter()
.all(|&idx| { (&graph[idx]).is_valid() })); .all(|&idx| { (&graph[idx]).is_valid() }));
self.is_valid.set(true); self.is_valid.set(true);
} }
@ -300,13 +361,51 @@ impl<O: 'static> Graph<O, Asynchronous> {
async fn update_invalid_nodes(&mut self) { async fn update_invalid_nodes(&mut self) {
// TODO: consider whether this can be done in parallel to any degree. // TODO: consider whether this can be done in parallel to any degree.
let mut graph = self.node_graph.borrow_mut(); let mut graph = self.node_graph.borrow_mut();
for &idx in self.sorted_nodes.iter() { let mut i = 0;
while i < self.sorted_nodes.len() {
let idx = self.sorted_nodes[i];
let node = &mut graph[idx]; let node = &mut graph[idx];
if !node.is_valid() { if !node.is_valid() {
// Update this node // Update this node
let value_changed = node.update().await; let mut ctx = NodeUpdateContext::new();
node.update(&mut ctx).await;
if value_changed { let mut nodes_changed = false;
for idx_to_remove in ctx.removed_nodes {
assert!(
idx_to_remove != idx,
"cannot remove node curently being evaluated"
);
let (index_to_remove_in_sorted, _) = self
.sorted_nodes
.iter()
.enumerate()
.find(|(_, idx)| **idx == idx_to_remove)
.expect("removed node must have been already added");
assert!(
index_to_remove_in_sorted > i,
"cannot remove already evaluated node"
);
graph.remove_node(idx_to_remove);
self.sorted_nodes.remove(index_to_remove_in_sorted);
nodes_changed = true;
}
for (added_node, id_cell) in ctx.added_nodes {
let id = graph.add_node(added_node);
id_cell.set(Some(id));
nodes_changed = true;
}
if nodes_changed {
// Update the graph before invalidating downstream nodes.
drop(graph);
self._modify(|_| {})
.expect("modifying graph during evaluation must produce valid graph");
graph = self.node_graph.borrow_mut();
}
if ctx.invalidate_dependent_nodes {
// Invalidate any downstream nodes (which we know we haven't visited yet, because // Invalidate any downstream nodes (which we know we haven't visited yet, because
// we're iterating over a topological sort of the graph). // we're iterating over a topological sort of the graph).
let dependents = graph let dependents = graph
@ -320,14 +419,25 @@ impl<O: 'static> Graph<O, Asynchronous> {
dependent.invalidate(); dependent.invalidate();
} }
} }
if nodes_changed {
// If we added/removed nodes, the sorted order has changed, so start evaluating
// from the beginning, in case of changes before i.
i = 0;
continue;
}
} }
i += 1;
} }
// Consistency check: after updating in the topological sort order, we should be left with // Consistency check: after updating in the topological sort order, we should be left with
// no invalid nodes // no invalid nodes
debug_assert!(self debug_assert!(self
.sorted_nodes .sorted_nodes
.iter() .iter()
.all(|&idx| { (&graph[idx]).is_valid() })); .all(|&idx| { (&graph[idx]).is_valid() }));
self.is_valid.set(true); self.is_valid.set(true);
} }
@ -420,7 +530,9 @@ impl<V> Clone for ValueInvalidationSignal<V> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::rule::{ConstantRule, InputVisitable}; use crate::rule::{
AsyncDynamicRule, AsyncRule, ConstantRule, DynamicInput, DynamicRule, InputVisitable, Rule,
};
#[test] #[test]
fn rule_output_with_no_inputs() { fn rule_output_with_no_inputs() {
@ -711,13 +823,108 @@ mod tests {
assert_eq!( assert_eq!(
graph.as_dot_string(), graph.as_dot_string(),
r#"digraph { r#"digraph {
0 [label ="ConstNode<i32> (id=0)"] 0 [label="ConstNode<i32> (id=0)"]
1 [label ="ConstNode<i32> (id=1)"] 1 [label="ConstNode<i32> (id=1)"]
2 [label ="RuleNode<compute_graph::tests::graphviz::AddWithLabel>(test) (id=2)"] 2 [label="RuleNode<AddWithLabel>(test) (id=2)"]
0 -> 2 [] 0 -> 2 []
1 -> 2 [] 1 -> 2 []
} }
"# "#
) )
} }
#[test]
fn dynamic_rule() {
let mut builder = GraphBuilder::new();
let (count, set_count) = builder.add_invalidatable_value(1);
struct CountUpTo(Input<i32>, Vec<Input<i32>>);
impl InputVisitable for CountUpTo {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
}
}
impl DynamicRule for CountUpTo {
type ChildOutput = i32;
fn evaluate(
&mut self,
ctx: &mut impl rule::DynamicRuleContext,
) -> Vec<Input<Self::ChildOutput>> {
let count = *self.0.value();
assert!(count >= self.1.len() as i32);
while (self.1.len() as i32) < count {
self.1
.push(ctx.add_rule(ConstantRule::new(self.1.len() as i32 + 1)));
}
self.1.clone()
}
}
let all_inputs = builder.add_dynamic_rule(CountUpTo(count, vec![]));
struct Sum(DynamicInput<i32>);
impl InputVisitable for Sum {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit_dynamic(&self.0);
}
}
impl Rule for Sum {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
self.0.value().inputs.iter().map(|i| *i.value()).sum()
}
}
builder.set_output(Sum(all_inputs));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 1);
set_count.set_value(2);
assert_eq!(*graph.evaluate(), 3);
set_count.set_value(4);
assert_eq!(*graph.evaluate(), 10);
println!("{}", graph.as_dot_string());
}
#[tokio::test]
async fn async_dynamic_rule() {
let mut builder = GraphBuilder::new_async();
let (count, set_count) = builder.add_invalidatable_value(1);
struct CountUpTo(Input<i32>, Vec<Input<i32>>);
impl InputVisitable for CountUpTo {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
}
}
impl AsyncDynamicRule for CountUpTo {
type ChildOutput = i32;
async fn evaluate<'a>(
&'a mut self,
ctx: &'a mut impl rule::AsyncDynamicRuleContext,
) -> Vec<Input<Self::ChildOutput>> {
let count = *self.0.value();
assert!(count >= self.1.len() as i32);
while (self.1.len() as i32) < count {
self.1
.push(ctx.add_rule(ConstantRule::new(self.1.len() as i32 + 1)));
}
self.1.clone()
}
}
let all_inputs = builder.add_async_dynamic_rule(CountUpTo(count, vec![]));
struct Sum(DynamicInput<i32>);
impl InputVisitable for Sum {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit_dynamic(&self.0);
}
}
impl Rule for Sum {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
self.0.value().inputs.iter().map(|i| *i.value()).sum()
}
}
builder.set_output(Sum(all_inputs));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate_async().await, 1);
set_count.set_value(2);
assert_eq!(*graph.evaluate_async().await, 3);
set_count.set_value(4);
assert_eq!(*graph.evaluate_async().await, 10);
}
} }

View File

@ -1,8 +1,12 @@
use crate::rule::{
AsyncDynamicRule, AsyncDynamicRuleContext, AsyncRule, DynamicInput, DynamicRule,
DynamicRuleContext, InputVisitable, Rule,
};
use crate::synchronicity::{Asynchronous, Synchronicity}; use crate::synchronicity::{Asynchronous, Synchronicity};
use crate::{AsyncRule, Input, InputVisitor, NodeId, Rule, Synchronous}; use crate::{Input, InputVisitor, NodeId, Synchronous};
use quote::ToTokens; use quote::ToTokens;
use std::any::Any; use std::any::Any;
use std::cell::RefCell; use std::cell::{Cell, RefCell};
use std::future::Future; use std::future::Future;
use std::rc::Rc; use std::rc::Rc;
@ -11,10 +15,35 @@ pub(crate) struct ErasedNode<Synch: Synchronicity> {
is_valid: Box<dyn Fn(&Box<dyn Any>) -> bool>, is_valid: Box<dyn Fn(&Box<dyn Any>) -> bool>,
invalidate: Box<dyn Fn(&mut Box<dyn Any>) -> ()>, invalidate: Box<dyn Fn(&mut Box<dyn Any>) -> ()>,
visit_inputs: Box<dyn Fn(&Box<dyn Any>, &mut dyn FnMut(NodeId) -> ()) -> ()>, visit_inputs: Box<dyn Fn(&Box<dyn Any>, &mut dyn FnMut(NodeId) -> ()) -> ()>,
update: Box<dyn for<'a> Fn(&'a mut Box<dyn Any>) -> Synch::UpdateResult<'a>>, update: Box<
dyn for<'a> Fn(
&'a mut Box<dyn Any>,
&'a mut NodeUpdateContext<Synch>,
) -> Synch::UpdateResult<'a>,
>,
debug_fmt: Box<dyn Fn(&Box<dyn Any>, &mut std::fmt::Formatter<'_>) -> std::fmt::Result>, debug_fmt: Box<dyn Fn(&Box<dyn Any>, &mut std::fmt::Formatter<'_>) -> std::fmt::Result>,
} }
pub(crate) struct NodeUpdateContext<Synch: Synchronicity> {
pub(crate) invalidate_dependent_nodes: bool,
pub(crate) removed_nodes: Vec<NodeId>,
pub(crate) added_nodes: Vec<(ErasedNode<Synch>, Rc<Cell<Option<NodeId>>>)>,
}
impl<S: Synchronicity> NodeUpdateContext<S> {
pub(crate) fn new() -> Self {
Self {
invalidate_dependent_nodes: false,
removed_nodes: vec![],
added_nodes: vec![],
}
}
fn invalidate_dependent_nodes(&mut self) {
self.invalidate_dependent_nodes = true;
}
}
impl<S: Synchronicity> ErasedNode<S> { impl<S: Synchronicity> ErasedNode<S> {
pub(crate) fn new<N: Node<V, S> + 'static, V: NodeValue>(base: N) -> Self { pub(crate) fn new<N: Node<V, S> + 'static, V: NodeValue>(base: N) -> Self {
// i don't love the double boxing, but i'm not sure how else to do this // i don't love the double boxing, but i'm not sure how else to do this
@ -34,9 +63,9 @@ impl<S: Synchronicity> ErasedNode<S> {
let x = any.downcast_ref::<Box<dyn Node<V, S>>>().unwrap(); let x = any.downcast_ref::<Box<dyn Node<V, S>>>().unwrap();
x.visit_inputs(visitor); x.visit_inputs(visitor);
}), }),
update: Box::new(|any| { update: Box::new(|any, ctx| {
let x = any.downcast_mut::<Box<dyn Node<V, S>>>().unwrap(); let x = any.downcast_mut::<Box<dyn Node<V, S>>>().unwrap();
x.update() x.update(ctx)
}), }),
debug_fmt: Box::new(|any, f| { debug_fmt: Box::new(|any, f| {
let x = any.downcast_ref::<Box<dyn Node<V, S>>>().unwrap(); let x = any.downcast_ref::<Box<dyn Node<V, S>>>().unwrap();
@ -57,14 +86,14 @@ impl<S: Synchronicity> ErasedNode<S> {
} }
impl ErasedNode<Synchronous> { impl ErasedNode<Synchronous> {
pub(crate) fn update(&mut self) -> bool { pub(crate) fn update(&mut self, ctx: &mut NodeUpdateContext<Synchronous>) {
(self.update)(&mut self.any) (self.update)(&mut self.any, ctx)
} }
} }
impl ErasedNode<Asynchronous> { impl ErasedNode<Asynchronous> {
pub(crate) async fn update(&mut self) -> bool { pub(crate) async fn update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
(self.update)(&mut self.any).await (self.update)(&mut self.any, ctx).await
} }
} }
@ -78,7 +107,7 @@ pub(crate) trait Node<Value: NodeValue, Synch: Synchronicity>: std::fmt::Debug {
fn is_valid(&self) -> bool; fn is_valid(&self) -> bool;
fn invalidate(&mut self); fn invalidate(&mut self);
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()); fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ());
fn update(&mut self) -> Synch::UpdateResult<'_>; fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<Synch>) -> Synch::UpdateResult<'a>;
fn value_rc(&self) -> &Rc<RefCell<Option<Value>>>; fn value_rc(&self) -> &Rc<RefCell<Option<Value>>>;
} }
@ -139,7 +168,7 @@ impl<V: NodeValue, S: Synchronicity> Node<V, S> for ConstNode<V, S> {
fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {} fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}
fn update(&mut self) -> S::UpdateResult<'_> { fn update<'a>(&'a mut self, _ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
unreachable!() unreachable!()
} }
@ -181,11 +210,12 @@ impl<V: NodeValue, S: Synchronicity> Node<V, S> for InvalidatableConstNode<V, S>
fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {} fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}
fn update(&mut self) -> S::UpdateResult<'_> { fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
self.valid = true; self.valid = true;
// This node is only invalidate when node_value_eq between the old/new value is false, // This node is only invalidate when node_value_eq between the old/new value is false,
// so it is always the case that the update method has changed the value. // so it is always the case that the update method has changed the value.
S::make_update_result(true, crate::synchronicity::private::Token) ctx.invalidate_dependent_nodes();
S::make_update_result(crate::synchronicity::private::Token)
} }
fn value_rc(&self) -> &Rc<RefCell<Option<V>>> { fn value_rc(&self) -> &Rc<RefCell<Option<V>>> {
@ -217,6 +247,32 @@ impl<R: Rule, S> RuleNode<R, R::Output, S> {
} }
} }
fn visit_inputs<V: InputVisitable>(visitable: &V, visitor: &mut dyn FnMut(NodeId) -> ()) {
struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ());
impl<'a> InputVisitor for InputIndexVisitor<'a> {
fn visit<T>(&mut self, input: &Input<T>) {
self.0(input.node_idx.get().unwrap());
}
fn visit_dynamic<T>(&mut self, input: &DynamicInput<T>) {
// Visit the dynamic node itself
self.visit(&input.input);
// And visit all the nodes it produces
let maybe_dynamic_output = input.input.value.borrow();
if let Some(dynamic_output) = maybe_dynamic_output.as_ref() {
for input in dynamic_output.inputs.iter() {
self.visit(input);
}
} else {
// Haven't evaluated the dynamic node for the first time yet.
// Upon doing so, if the nodes it produces change, we'll modify the graph
// and end up back here in the other branch.
}
}
}
visitable.visit_inputs(&mut InputIndexVisitor(visitor));
}
impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S> { impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S> {
fn is_valid(&self) -> bool { fn is_valid(&self) -> bool {
self.valid self.valid
@ -227,16 +283,10 @@ impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S>
} }
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) { fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ()); visit_inputs(&self.rule, visitor);
impl<'a> InputVisitor for InputIndexVisitor<'a> {
fn visit<T>(&mut self, input: &Input<T>) {
self.0(input.node_idx);
}
}
self.rule.visit_inputs(&mut InputIndexVisitor(visitor));
} }
fn update(&mut self) -> S::UpdateResult<'_> { fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
self.valid = true; self.valid = true;
let new_value = self.rule.evaluate(); let new_value = self.rule.evaluate();
@ -247,9 +297,10 @@ impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S>
if value_changed { if value_changed {
*value = Some(new_value); *value = Some(new_value);
ctx.invalidate_dependent_nodes();
} }
S::make_update_result(value_changed, crate::synchronicity::private::Token) S::make_update_result(crate::synchronicity::private::Token)
} }
fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> { fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> {
@ -290,12 +341,12 @@ impl<V, P: FnOnce() -> F, F: Future<Output = V>> AsyncConstNode<V, P, F> {
} }
} }
async fn do_update(&mut self) -> bool { async fn do_update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
self.valid = true; self.valid = true;
let mut provider = None; let mut provider = None;
std::mem::swap(&mut self.provider, &mut provider); std::mem::swap(&mut self.provider, &mut provider);
*self.value.borrow_mut() = Some(provider.unwrap()().await); *self.value.borrow_mut() = Some(provider.unwrap()().await);
true ctx.invalidate_dependent_nodes();
} }
} }
@ -312,8 +363,11 @@ impl<V: NodeValue, P: FnOnce() -> F, F: Future<Output = V>> Node<V, Asynchronous
fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {} fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {}
fn update(&mut self) -> <Asynchronous as Synchronicity>::UpdateResult<'_> { fn update<'a>(
Box::pin(self.do_update()) &'a mut self,
ctx: &'a mut NodeUpdateContext<Asynchronous>,
) -> <Asynchronous as Synchronicity>::UpdateResult<'a> {
Box::pin(self.do_update(ctx))
} }
fn value_rc(&self) -> &Rc<RefCell<Option<V>>> { fn value_rc(&self) -> &Rc<RefCell<Option<V>>> {
@ -342,7 +396,7 @@ impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
} }
} }
async fn do_update(&mut self) -> bool { async fn do_update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
self.valid = true; self.valid = true;
let new_value = self.rule.evaluate().await; let new_value = self.rule.evaluate().await;
@ -353,9 +407,8 @@ impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
if value_changed { if value_changed {
*value = Some(new_value); *value = Some(new_value);
ctx.invalidate_dependent_nodes();
} }
value_changed
} }
} }
@ -369,17 +422,14 @@ impl<R: AsyncRule> Node<R::Output, Asynchronous> for AsyncRuleNode<R, R::Output>
} }
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) { fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ()); visit_inputs(&self.rule, visitor);
impl<'a> InputVisitor for InputIndexVisitor<'a> {
fn visit<T>(&mut self, input: &Input<T>) {
self.0(input.node_idx);
}
}
self.rule.visit_inputs(&mut InputIndexVisitor(visitor));
} }
fn update(&mut self) -> <Asynchronous as Synchronicity>::UpdateResult<'_> { fn update<'a>(
Box::pin(self.do_update()) &'a mut self,
ctx: &'a mut NodeUpdateContext<Asynchronous>,
) -> <Asynchronous as Synchronicity>::UpdateResult<'a> {
Box::pin(self.do_update(ctx))
} }
fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> { fn value_rc(&self) -> &Rc<RefCell<Option<R::Output>>> {
@ -405,6 +455,236 @@ impl<R: AsyncRule, V> std::fmt::Debug for AsyncRuleNode<R, V> {
} }
} }
// todo: better name for this
pub struct DynamicRuleOutput<O> {
pub inputs: Vec<Input<O>>,
}
impl<O: 'static> NodeValue for DynamicRuleOutput<O> {
fn node_value_eq(&self, other: &Self) -> bool {
if self.inputs.len() != other.inputs.len() {
return false;
}
self.inputs
.iter()
.zip(other.inputs.iter())
.all(|(s, o)| s.node_idx == o.node_idx)
}
}
impl<O> std::fmt::Debug for DynamicRuleOutput<O> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct(std::any::type_name::<Self>())
.field("inputs", &self.inputs)
.finish()
}
}
pub(crate) struct DynamicRuleNode<R, O, S> {
rule: R,
valid: bool,
value: Rc<RefCell<Option<DynamicRuleOutput<O>>>>,
synchronicity: std::marker::PhantomData<S>,
}
impl<R, O, S> DynamicRuleNode<R, O, S> {
pub(crate) fn new(rule: R) -> Self {
Self {
rule,
valid: false,
value: Rc::new(RefCell::new(None)),
synchronicity: std::marker::PhantomData,
}
}
}
impl<R: DynamicRule, S: Synchronicity> Node<DynamicRuleOutput<R::ChildOutput>, S>
for DynamicRuleNode<R, R::ChildOutput, S>
{
fn is_valid(&self) -> bool {
self.valid
}
fn invalidate(&mut self) {
self.valid = false;
}
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
visit_inputs(&self.rule, visitor);
}
fn update<'a>(&'a mut self, ctx: &'a mut NodeUpdateContext<S>) -> S::UpdateResult<'a> {
self.valid = true;
let new_value = DynamicRuleOutput {
inputs: self.rule.evaluate(&mut DynamicRuleUpdateContext(ctx)),
};
let mut value = self.value.borrow_mut();
let value_changed = value
.as_ref()
.map_or(true, |v| !v.node_value_eq(&new_value));
if value_changed {
*value = Some(new_value);
ctx.invalidate_dependent_nodes();
}
S::make_update_result(crate::synchronicity::private::Token)
}
fn value_rc(&self) -> &Rc<RefCell<Option<DynamicRuleOutput<R::ChildOutput>>>> {
&self.value
}
}
struct DynamicRuleUpdateContext<'a, Synch: Synchronicity>(&'a mut NodeUpdateContext<Synch>);
impl<'a, S: Synchronicity> DynamicRuleUpdateContext<'a, S> {
fn add_node<V: NodeValue>(&mut self, node: impl Node<V, S> + 'static) -> Input<V> {
let node_idx = Rc::new(Cell::new(None));
let value = Rc::clone(node.value_rc());
let erased = ErasedNode::new(node);
self.0.added_nodes.push((erased, Rc::clone(&node_idx)));
Input { node_idx, value }
}
}
impl<'a, S: Synchronicity> DynamicRuleContext for DynamicRuleUpdateContext<'a, S> {
fn remove_node(&mut self, id: NodeId) {
self.0.removed_nodes.push(id);
}
fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: Rule,
{
self.add_node(RuleNode::new(rule))
}
}
struct DynamicRuleLabel<'a, R: DynamicRule>(&'a R);
impl<'a, R: DynamicRule> std::fmt::Display for DynamicRuleLabel<'a, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.node_label(f)
}
}
impl<R: DynamicRule, O, V> std::fmt::Debug for DynamicRuleNode<R, O, V> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"DynamicRuleNode<{}>({})",
pretty_type_name::<R>(),
DynamicRuleLabel(&self.rule)
)
}
}
pub(crate) struct AsyncDynamicRuleNode<R, O> {
rule: R,
valid: bool,
value: Rc<RefCell<Option<DynamicRuleOutput<O>>>>,
}
impl<R: AsyncDynamicRule> AsyncDynamicRuleNode<R, R::ChildOutput> {
pub(crate) fn new(rule: R) -> Self {
Self {
rule,
valid: false,
value: Rc::new(RefCell::new(None)),
}
}
async fn do_update(&mut self, ctx: &mut NodeUpdateContext<Asynchronous>) {
self.valid = true;
let new_value = DynamicRuleOutput {
inputs: self
.rule
.evaluate(&mut AsyncDynamicRuleUpdateContext(ctx))
.await,
};
let mut value = self.value.borrow_mut();
let value_changed = value
.as_ref()
.map_or(true, |v| !v.node_value_eq(&new_value));
if value_changed {
*value = Some(new_value);
ctx.invalidate_dependent_nodes();
}
}
}
impl<R: AsyncDynamicRule> Node<DynamicRuleOutput<R::ChildOutput>, Asynchronous>
for AsyncDynamicRuleNode<R, R::ChildOutput>
{
fn is_valid(&self) -> bool {
self.valid
}
fn invalidate(&mut self) {
self.valid = false;
}
fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) {
visit_inputs(&self.rule, visitor);
}
fn update<'a>(
&'a mut self,
ctx: &'a mut NodeUpdateContext<Asynchronous>,
) -> <Asynchronous as Synchronicity>::UpdateResult<'a> {
Box::pin(self.do_update(ctx))
}
fn value_rc(&self) -> &Rc<RefCell<Option<DynamicRuleOutput<R::ChildOutput>>>> {
&self.value
}
}
struct AsyncDynamicRuleUpdateContext<'a>(&'a mut NodeUpdateContext<Asynchronous>);
impl<'a> DynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> {
fn remove_node(&mut self, id: NodeId) {
DynamicRuleUpdateContext(self.0).remove_node(id);
}
fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: Rule,
{
DynamicRuleUpdateContext(self.0).add_rule(rule)
}
}
impl<'a> AsyncDynamicRuleContext for AsyncDynamicRuleUpdateContext<'a> {
fn add_async_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: AsyncRule,
{
DynamicRuleUpdateContext(self.0).add_node(AsyncRuleNode::new(rule))
}
}
struct AsyncDynamicRuleLabel<'a, R: AsyncDynamicRule>(&'a R);
impl<'a, R: AsyncDynamicRule> std::fmt::Display for AsyncDynamicRuleLabel<'a, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.node_label(f)
}
}
impl<R: AsyncDynamicRule> std::fmt::Debug for AsyncDynamicRuleNode<R, R::ChildOutput> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"AsyncDynamicRuleNode<{}>({})",
pretty_type_name::<R>(),
AsyncDynamicRuleLabel(&self.rule)
)
}
}
fn pretty_type_name<T>() -> String { fn pretty_type_name<T>() -> String {
let s = std::any::type_name::<T>(); let s = std::any::type_name::<T>();
let ty = syn::parse_str::<syn::Type>(s).unwrap(); let ty = syn::parse_str::<syn::Type>(s).unwrap();

View File

@ -1,7 +1,7 @@
use crate::node::NodeValue; use crate::node::{DynamicRuleOutput, NodeValue};
use crate::NodeId; use crate::NodeId;
pub use compute_graph_macros::InputVisitable; pub use compute_graph_macros::InputVisitable;
use std::cell::{Ref, RefCell}; use std::cell::{Cell, Ref, RefCell};
use std::future::Future; use std::future::Future;
use std::ops::Deref; use std::ops::Deref;
use std::rc::Rc; use std::rc::Rc;
@ -76,6 +76,75 @@ pub trait AsyncRule: InputVisitable + 'static {
} }
} }
/// A rule whose output is further nodes in the graph.
///
/// Types implementing this rule should track which nodes they previously output and not
/// add additional equivalent nodes (for whatever domain-specific definition of equivalence)
/// on susbequent evaluations.
pub trait DynamicRule: InputVisitable + 'static {
/// The type of the output value of each of the child nodes that this rule produces.
type ChildOutput: NodeValue;
/// Evaluates this rule, producing additional nodes.
///
/// Use the methods on [`DynamicRuleContext`] to add or remove nodes from the graph.
fn evaluate(&mut self, ctx: &mut impl DynamicRuleContext) -> Vec<Input<Self::ChildOutput>>;
#[allow(unused_variables)]
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
Ok(())
}
}
/// Facilities for adding/removing nodes in the graph during the update of a [`DynamicRule`].
// todo: better abstracion for this
// something that handles diffing and does the add/remove automatically
pub trait DynamicRuleContext {
/// Removes the node with the given ID from the graph.
///
/// Be careful when removing nodes. Removing a node that is still depended-upon by another node
/// (i.e., is an input in some other node's [`InputVisitable::visit_inputs`]) is an error.
fn remove_node(&mut self, id: NodeId);
/// Adds a node whose value is produced using the given rule to the graph.
///
/// Returns an [`Input`] representing the newly-added node, which can be used to construct further rules.
fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: Rule;
}
/// An asynchronous rule whose output is further nodes in the graph.
///
/// See [`DynamicRule`].
pub trait AsyncDynamicRule: InputVisitable + 'static {
/// The type of the output value of each of the child nodes that this rule produces.
type ChildOutput: NodeValue;
/// Evaluates this rule asynchronously, producing additional nodes.
///
/// Use the methods on [`AsyncDynamicRuleContext`] to add or remove nodes from the graph.
fn evaluate<'a>(
&'a mut self,
ctx: &'a mut impl AsyncDynamicRuleContext,
) -> impl Future<Output = Vec<Input<Self::ChildOutput>>> + 'a;
#[allow(unused_variables)]
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
Ok(())
}
}
/// Facilities for adding/removing nodes in the graph during the update of an [`AsyncDynamicRule`].
pub trait AsyncDynamicRuleContext: DynamicRuleContext {
/// Adds a node whose value is produced using the given rule to the graph.
///
/// Returns an [`Input`] representing the newly-added node, which can be used to construct further rules.
fn add_async_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: AsyncRule;
}
/// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited. /// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited.
/// ///
/// The implementation of this trait can generally be derived using [`derive@InputVisitable`]. /// The implementation of this trait can generally be derived using [`derive@InputVisitable`].
@ -93,13 +162,13 @@ pub trait InputVisitable {
fn visit_inputs(&self, visitor: &mut impl InputVisitor); fn visit_inputs(&self, visitor: &mut impl InputVisitor);
} }
/// A placeholder for the output of one node to be used as an input for another. /// A placeholder for the output of one node, to be used as an input for another.
/// ///
/// To obtain an input, add a value or rule to a [`GraphBuilder`](`crate::builder::GraphBuilder`). /// To obtain an input, add a value or rule to a [`GraphBuilder`](`crate::builder::GraphBuilder`).
/// ///
/// Note that this type implements `Clone`, so can be cloned and used as an input for multiple nodes. /// Note that this type implements `Clone`, so can be cloned and used as an input for multiple nodes.
pub struct Input<T> { pub struct Input<T> {
pub(crate) node_idx: NodeId, pub(crate) node_idx: Rc<Cell<Option<NodeId>>>,
pub(crate) value: Rc<RefCell<Option<T>>>, pub(crate) value: Rc<RefCell<Option<T>>>,
} }
@ -119,7 +188,7 @@ impl<T> Input<T> {
impl<T> Clone for Input<T> { impl<T> Clone for Input<T> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
node_idx: self.node_idx, node_idx: Rc::clone(&self.node_idx),
value: Rc::clone(&self.value), value: Rc::clone(&self.value),
} }
} }
@ -136,6 +205,25 @@ impl<T> std::fmt::Debug for Input<T> {
} }
} }
/// A placeholder for the output of a dynamic rule node, to be used as an input for another.
///
/// See [`GraphBuilder::add_dynamic_rule`](`crate::builder::GraphBuilder::add_dynamic_rule`).
///
/// A dependency on a dynamic input represents both a dependency on the dynamic node itself,
/// as well as dependencies on each of the nodes that are the output of the dynamic node.
#[derive(Clone)]
pub struct DynamicInput<T> {
pub(crate) input: Input<DynamicRuleOutput<T>>,
}
impl<T> DynamicInput<T> {
/// Retrieves a reference to the current value of the dynamic node (i.e., the set of inputs
/// representing the nodes that are the outputs of the dynamic node).
pub fn value(&self) -> impl Deref<Target = DynamicRuleOutput<T>> + '_ {
self.input.value()
}
}
// TODO: i really want Input to be able to implement Deref somehow // TODO: i really want Input to be able to implement Deref somehow
/// A type that can visit arbitrary [`Input`]s. /// A type that can visit arbitrary [`Input`]s.
@ -145,6 +233,9 @@ impl<T> std::fmt::Debug for Input<T> {
pub trait InputVisitor { pub trait InputVisitor {
/// Visit an input whose value is of type `T`. /// Visit an input whose value is of type `T`.
fn visit<T>(&mut self, input: &Input<T>); fn visit<T>(&mut self, input: &Input<T>);
/// Visit a dynamic input whose child value is of type `T`.
fn visit_dynamic<T>(&mut self, input: &DynamicInput<T>);
} }
/// A simple rule that provides a constant value. /// A simple rule that provides a constant value.

View File

@ -11,7 +11,7 @@ pub(crate) mod private {
pub trait Sealed {} pub trait Sealed {}
impl Sealed for super::Synchronous {} impl Sealed for super::Synchronous {}
impl Sealed for super::Asynchronous {} impl Sealed for super::Asynchronous {}
impl Sealed for bool {} impl Sealed for () {}
impl<'a> Sealed for <super::Asynchronous as super::Synchronicity>::UpdateResult<'a> {} impl<'a> Sealed for <super::Asynchronous as super::Synchronicity>::UpdateResult<'a> {}
pub struct Token; pub struct Token;
} }
@ -20,25 +20,23 @@ pub trait Synchronicity: private::Sealed + 'static {
type UpdateResult<'a>: private::Sealed; type UpdateResult<'a>: private::Sealed;
// Necessary for synchronous nodes that can be part of an async graph to return the // Necessary for synchronous nodes that can be part of an async graph to return the
// appropriate result based on the type of graph they're in. // appropriate result based on the type of graph they're in.
fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a>; fn make_update_result<'a>(_: private::Token) -> Self::UpdateResult<'a>;
} }
pub struct Synchronous; pub struct Synchronous;
impl Synchronicity for Synchronous { impl Synchronicity for Synchronous {
type UpdateResult<'a> = bool; type UpdateResult<'a> = ();
fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a> { fn make_update_result<'a>(_: private::Token) -> Self::UpdateResult<'a> {}
result
}
} }
pub struct Asynchronous; pub struct Asynchronous;
impl Synchronicity for Asynchronous { impl Synchronicity for Asynchronous {
type UpdateResult<'a> = Pin<Box<dyn Future<Output = bool> + 'a>>; type UpdateResult<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;
fn make_update_result<'a>(result: bool, _: private::Token) -> Self::UpdateResult<'a> { fn make_update_result<'a>(_: private::Token) -> Self::UpdateResult<'a> {
Box::pin(std::future::ready(result)) Box::pin(std::future::ready(()))
} }
} }

View File

@ -1,6 +1,6 @@
use proc_macro::TokenStream; use proc_macro::TokenStream;
use proc_macro2::Literal; use proc_macro2::Literal;
use quote::{format_ident, quote}; use quote::{format_ident, quote, ToTokens};
use syn::{ use syn::{
parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, GenericParam, parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, GenericParam,
PathArguments, Type, PathArguments, Type,
@ -10,8 +10,8 @@ extern crate proc_macro;
/// Derive an implementation of the `InputVisitable` trait and helper methods. /// Derive an implementation of the `InputVisitable` trait and helper methods.
/// ///
/// This macro generates an implementation of the `InputVisitable` trait and the `visit_input` method that /// This macro generates an implementation of the `InputVisitable` trait and the `visit_inputs` method that
/// calls `visit` on each field of the struct that is of type `Input<T>` for any T. /// calls `visit` on each field of the struct that is of type `Input<T>` or `DynamicInput<T>` for any `T`.
/// ///
/// The macro also generates helper methods for accessing the value of each input less verbosely. /// The macro also generates helper methods for accessing the value of each input less verbosely.
/// For unnamed struct fields, the methods generated have the form `input_0`, `input_1`, etc. /// For unnamed struct fields, the methods generated have the form `input_0`, `input_1`, etc.
@ -56,20 +56,34 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
Fields::Named(ref named) => named Fields::Named(ref named) => named
.named .named
.iter() .iter()
.filter(|field| input_value_type(field).is_some()) .flat_map(|field| {
.map(|field| { if let Some((_ty, is_dynamic)) = input_value_type(field) {
let ident = field.ident.as_ref().unwrap(); let ident = field.ident.as_ref().unwrap();
quote!(visitor.visit(&self.#ident);) if is_dynamic {
Some(quote!(visitor.visit_dynamic(&self.#ident);))
} else {
Some(quote!(visitor.visit(&self.#ident);))
}
} else {
None
}
}) })
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
Fields::Unnamed(ref unnamed) => unnamed Fields::Unnamed(ref unnamed) => unnamed
.unnamed .unnamed
.iter() .iter()
.enumerate() .enumerate()
.filter(|(_, field)| input_value_type(field).is_some()) .flat_map(|(i, field)| {
.map(|(i, _)| { if let Some((_ty, is_dynamic)) = input_value_type(field) {
let idx_lit = Literal::usize_unsuffixed(i); let idx_lit = Literal::usize_unsuffixed(i);
quote!(visitor.visit(&self.#idx_lit);) if is_dynamic {
Some(quote!(visitor.visit_dynamic(&self.#idx_lit);))
} else {
Some(quote!(visitor.visit(&self.#idx_lit);))
}
} else {
None
}
}) })
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
Fields::Unit => vec![], Fields::Unit => vec![],
@ -79,12 +93,19 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
Fields::Named(ref named) => named Fields::Named(ref named) => named
.named .named
.iter() .iter()
.filter_map(|field| input_value_type(field).map(|ty| (field, ty))) .filter_map(|field| {
.map(|(field, ty)| { input_value_type(field).map(|(ty, is_dynamic)| (field, ty, is_dynamic))
})
.map(|(field, ty, is_dynamic)| {
let ident = field.ident.as_ref().unwrap(); let ident = field.ident.as_ref().unwrap();
let target = if is_dynamic {
quote!(::compute_graph::node::DynamicRuleOutput<#ty>)
} else {
ty.to_token_stream()
};
quote!( quote!(
fn #ident(&self) -> impl ::std::ops::Deref<Target = #ty> + '_ { fn #ident(&self) -> impl ::std::ops::Deref<Target = #target> + '_ {
self.#ident.value() self.#ident.value()
} }
@ -95,13 +116,20 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
.unnamed .unnamed
.iter() .iter()
.enumerate() .enumerate()
.filter_map(|(i, field)| input_value_type(field).map(|ty| (i, ty))) .filter_map(|(i, field)| {
.map(|(i, ty)| { input_value_type(field).map(|(ty, is_dynamic)| (i, ty, is_dynamic))
})
.map(|(i, ty, is_dynamic)| {
let idx_lit = Literal::usize_unsuffixed(i); let idx_lit = Literal::usize_unsuffixed(i);
let ident = format_ident!("input_{i}"); let ident = format_ident!("input_{i}");
let target = if is_dynamic {
quote!(::compute_graph::node::DynamicRuleOutput<#ty>)
} else {
ty.to_token_stream()
};
quote!( quote!(
fn #ident(&self) -> impl ::std::ops::Deref<Target = #ty> + '_ { fn #ident(&self) -> impl ::std::ops::Deref<Target = #target> + '_ {
self.#idx_lit.value() self.#idx_lit.value()
} }
@ -126,14 +154,15 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
)) ))
} }
fn input_value_type(field: &Field) -> Option<&Type> { fn input_value_type(field: &Field) -> Option<(&Type, bool)> {
if let Type::Path(ref path) = field.ty { if let Type::Path(ref path) = field.ty {
let last_segment = path.path.segments.last().unwrap(); let last_segment = path.path.segments.last().unwrap();
if last_segment.ident == "Input" { if last_segment.ident == "Input" || last_segment.ident == "DynamicInput" {
let is_dynamic = last_segment.ident == "DynamicInput";
if let PathArguments::AngleBracketed(ref args) = last_segment.arguments { if let PathArguments::AngleBracketed(ref args) = last_segment.arguments {
if args.args.len() == 1 { if args.args.len() == 1 {
if let GenericArgument::Type(ref ty) = args.args.first().unwrap() { if let GenericArgument::Type(ref ty) = args.args.first().unwrap() {
Some(ty) Some((ty, is_dynamic))
} else { } else {
None None
} }

View File

@ -1,5 +1,5 @@
use compute_graph::node::NodeValue; use compute_graph::node::NodeValue;
use compute_graph::rule::{Input, InputVisitable, Rule}; use compute_graph::rule::{DynamicInput, Input, InputVisitable, Rule};
#[derive(InputVisitable)] #[derive(InputVisitable)]
struct Add(Input<i32>, Input<i32>, i32); struct Add(Input<i32>, Input<i32>, i32);
@ -34,9 +34,25 @@ impl<T: NodeValue + Clone> Rule for Passthrough<T> {
} }
} }
#[derive(InputVisitable)]
struct Sum(DynamicInput<i32>);
impl Rule for Sum {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
self.input_0()
.inputs
.iter()
.map(|input| *input.value())
.sum()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use compute_graph::builder::GraphBuilder; use compute_graph::{
builder::GraphBuilder,
rule::{ConstantRule, DynamicRule},
};
use super::*; use super::*;
@ -59,4 +75,27 @@ mod tests {
let mut graph = builder.build().unwrap(); let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 6); assert_eq!(*graph.evaluate(), 6);
} }
#[test]
fn test_sum() {
#[derive(InputVisitable)]
struct Dynamic;
impl DynamicRule for Dynamic {
type ChildOutput = i32;
fn evaluate(
&mut self,
ctx: &mut impl compute_graph::rule::DynamicRuleContext,
) -> Vec<Input<Self::ChildOutput>> {
vec![
ctx.add_rule(ConstantRule::new(1)),
ctx.add_rule(ConstantRule::new(2)),
]
}
}
let mut builder = GraphBuilder::new();
let dynamic_input = builder.add_dynamic_rule(Dynamic);
builder.set_output(Sum(dynamic_input));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 3);
}
} }